state.py 54 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480
  1. from __future__ import annotations
  2. import inspect
  3. import logging
  4. import typing
  5. import warnings
  6. from collections import defaultdict
  7. from collections.abc import Awaitable, Callable, Hashable, Sequence
  8. from functools import partial
  9. from inspect import isclass, isfunction, ismethod, signature
  10. from types import FunctionType
  11. from types import NoneType as NoneType
  12. from typing import (
  13. Any,
  14. Generic,
  15. Literal,
  16. Union,
  17. cast,
  18. get_args,
  19. get_origin,
  20. get_type_hints,
  21. overload,
  22. )
  23. from langchain_core.runnables import Runnable, RunnableConfig
  24. from langgraph.cache.base import BaseCache
  25. from langgraph.checkpoint.base import Checkpoint
  26. from langgraph.store.base import BaseStore
  27. from pydantic import BaseModel, TypeAdapter
  28. from typing_extensions import NotRequired, Required, Self, Unpack, is_typeddict
  29. from langgraph._internal._constants import (
  30. INTERRUPT,
  31. NS_END,
  32. NS_SEP,
  33. TASKS,
  34. )
  35. from langgraph._internal._fields import (
  36. get_cached_annotated_keys,
  37. get_field_default,
  38. get_update_as_tuples,
  39. )
  40. from langgraph._internal._pydantic import create_model
  41. from langgraph._internal._runnable import coerce_to_runnable
  42. from langgraph._internal._typing import EMPTY_SEQ, MISSING, DeprecatedKwargs
  43. from langgraph.channels.base import BaseChannel
  44. from langgraph.channels.binop import BinaryOperatorAggregate
  45. from langgraph.channels.ephemeral_value import EphemeralValue
  46. from langgraph.channels.last_value import LastValue, LastValueAfterFinish
  47. from langgraph.channels.named_barrier_value import (
  48. NamedBarrierValue,
  49. NamedBarrierValueAfterFinish,
  50. )
  51. from langgraph.constants import END, START, TAG_HIDDEN
  52. from langgraph.errors import (
  53. ErrorCode,
  54. InvalidUpdateError,
  55. ParentCommand,
  56. create_error_message,
  57. )
  58. from langgraph.graph._branch import BranchSpec
  59. from langgraph.graph._node import StateNode, StateNodeSpec
  60. from langgraph.managed.base import (
  61. ManagedValueSpec,
  62. is_managed_value,
  63. )
  64. from langgraph.pregel import Pregel
  65. from langgraph.pregel._read import ChannelRead, PregelNode
  66. from langgraph.pregel._write import (
  67. ChannelWrite,
  68. ChannelWriteEntry,
  69. ChannelWriteTupleEntry,
  70. )
  71. from langgraph.types import (
  72. All,
  73. CachePolicy,
  74. Checkpointer,
  75. Command,
  76. RetryPolicy,
  77. Send,
  78. )
  79. from langgraph.typing import ContextT, InputT, NodeInputT, OutputT, StateT
  80. from langgraph.warnings import LangGraphDeprecatedSinceV05, LangGraphDeprecatedSinceV10
  81. __all__ = ("StateGraph", "CompiledStateGraph")
  82. logger = logging.getLogger(__name__)
  83. _CHANNEL_BRANCH_TO = "branch:to:{}"
  84. def _warn_invalid_state_schema(schema: type[Any] | Any) -> None:
  85. if isinstance(schema, type):
  86. return
  87. if typing.get_args(schema):
  88. return
  89. warnings.warn(
  90. f"Invalid state_schema: {schema}. Expected a type or Annotated[type, reducer]. "
  91. "Please provide a valid schema to ensure correct updates.\n"
  92. " See: https://langchain-ai.github.io/langgraph/reference/graphs/#stategraph"
  93. )
  94. def _get_node_name(node: StateNode[Any, ContextT]) -> str:
  95. try:
  96. return getattr(node, "__name__", node.__class__.__name__)
  97. except AttributeError:
  98. raise TypeError(f"Unsupported node type: {type(node)}")
  99. class StateGraph(Generic[StateT, ContextT, InputT, OutputT]):
  100. """A graph whose nodes communicate by reading and writing to a shared state.
  101. The signature of each node is `State -> Partial<State>`.
  102. Each state key can optionally be annotated with a reducer function that
  103. will be used to aggregate the values of that key received from multiple nodes.
  104. The signature of a reducer function is `(Value, Value) -> Value`.
  105. !!! warning
  106. `StateGraph` is a builder class and cannot be used directly for execution.
  107. You must first call `.compile()` to create an executable graph that supports
  108. methods like `invoke()`, `stream()`, `astream()`, and `ainvoke()`. See the
  109. `CompiledStateGraph` documentation for more details.
  110. Args:
  111. state_schema: The schema class that defines the state.
  112. context_schema: The schema class that defines the runtime context.
  113. Use this to expose immutable context data to your nodes, like `user_id`, `db_conn`, etc.
  114. input_schema: The schema class that defines the input to the graph.
  115. output_schema: The schema class that defines the output from the graph.
  116. !!! warning "`config_schema` Deprecated"
  117. The `config_schema` parameter is deprecated in v0.6.0 and support will be removed in v2.0.0.
  118. Please use `context_schema` instead to specify the schema for run-scoped context.
  119. Example:
  120. ```python
  121. from langchain_core.runnables import RunnableConfig
  122. from typing_extensions import Annotated, TypedDict
  123. from langgraph.checkpoint.memory import InMemorySaver
  124. from langgraph.graph import StateGraph
  125. from langgraph.runtime import Runtime
  126. def reducer(a: list, b: int | None) -> list:
  127. if b is not None:
  128. return a + [b]
  129. return a
  130. class State(TypedDict):
  131. x: Annotated[list, reducer]
  132. class Context(TypedDict):
  133. r: float
  134. graph = StateGraph(state_schema=State, context_schema=Context)
  135. def node(state: State, runtime: Runtime[Context]) -> dict:
  136. r = runtime.context.get("r", 1.0)
  137. x = state["x"][-1]
  138. next_value = x * r * (1 - x)
  139. return {"x": next_value}
  140. graph.add_node("A", node)
  141. graph.set_entry_point("A")
  142. graph.set_finish_point("A")
  143. compiled = graph.compile()
  144. step1 = compiled.invoke({"x": 0.5}, context={"r": 3.0})
  145. # {'x': [0.5, 0.75]}
  146. ```
  147. """
  148. edges: set[tuple[str, str]]
  149. nodes: dict[str, StateNodeSpec[Any, ContextT]]
  150. branches: defaultdict[str, dict[str, BranchSpec]]
  151. channels: dict[str, BaseChannel]
  152. managed: dict[str, ManagedValueSpec]
  153. schemas: dict[type[Any], dict[str, BaseChannel | ManagedValueSpec]]
  154. waiting_edges: set[tuple[tuple[str, ...], str]]
  155. compiled: bool
  156. state_schema: type[StateT]
  157. context_schema: type[ContextT] | None
  158. input_schema: type[InputT]
  159. output_schema: type[OutputT]
  160. def __init__(
  161. self,
  162. state_schema: type[StateT],
  163. context_schema: type[ContextT] | None = None,
  164. *,
  165. input_schema: type[InputT] | None = None,
  166. output_schema: type[OutputT] | None = None,
  167. **kwargs: Unpack[DeprecatedKwargs],
  168. ) -> None:
  169. if (config_schema := kwargs.get("config_schema", MISSING)) is not MISSING:
  170. warnings.warn(
  171. "`config_schema` is deprecated and will be removed. Please use `context_schema` instead.",
  172. category=LangGraphDeprecatedSinceV10,
  173. stacklevel=2,
  174. )
  175. if context_schema is None:
  176. context_schema = cast(type[ContextT], config_schema)
  177. if (input_ := kwargs.get("input", MISSING)) is not MISSING:
  178. warnings.warn(
  179. "`input` is deprecated and will be removed. Please use `input_schema` instead.",
  180. category=LangGraphDeprecatedSinceV05,
  181. stacklevel=2,
  182. )
  183. if input_schema is None:
  184. input_schema = cast(type[InputT], input_)
  185. if (output := kwargs.get("output", MISSING)) is not MISSING:
  186. warnings.warn(
  187. "`output` is deprecated and will be removed. Please use `output_schema` instead.",
  188. category=LangGraphDeprecatedSinceV05,
  189. stacklevel=2,
  190. )
  191. if output_schema is None:
  192. output_schema = cast(type[OutputT], output)
  193. self.nodes = {}
  194. self.edges = set()
  195. self.branches = defaultdict(dict)
  196. self.schemas = {}
  197. self.channels = {}
  198. self.managed = {}
  199. self.compiled = False
  200. self.waiting_edges = set()
  201. self.state_schema = state_schema
  202. self.input_schema = cast(type[InputT], input_schema or state_schema)
  203. self.output_schema = cast(type[OutputT], output_schema or state_schema)
  204. self.context_schema = context_schema
  205. self._add_schema(self.state_schema)
  206. self._add_schema(self.input_schema, allow_managed=False)
  207. self._add_schema(self.output_schema, allow_managed=False)
  208. @property
  209. def _all_edges(self) -> set[tuple[str, str]]:
  210. return self.edges | {
  211. (start, end) for starts, end in self.waiting_edges for start in starts
  212. }
  213. def _add_schema(self, schema: type[Any], /, allow_managed: bool = True) -> None:
  214. if schema not in self.schemas:
  215. _warn_invalid_state_schema(schema)
  216. channels, managed, type_hints = _get_channels(schema)
  217. if managed and not allow_managed:
  218. names = ", ".join(managed)
  219. schema_name = getattr(schema, "__name__", "")
  220. raise ValueError(
  221. f"Invalid managed channels detected in {schema_name}: {names}."
  222. " Managed channels are not permitted in Input/Output schema."
  223. )
  224. self.schemas[schema] = {**channels, **managed}
  225. for key, channel in channels.items():
  226. if key in self.channels:
  227. if self.channels[key] != channel:
  228. if isinstance(channel, LastValue):
  229. pass
  230. else:
  231. raise ValueError(
  232. f"Channel '{key}' already exists with a different type"
  233. )
  234. else:
  235. self.channels[key] = channel
  236. for key, managed in managed.items():
  237. if key in self.managed:
  238. if self.managed[key] != managed:
  239. raise ValueError(
  240. f"Managed value '{key}' already exists with a different type"
  241. )
  242. else:
  243. self.managed[key] = managed
  244. @overload
  245. def add_node(
  246. self,
  247. node: StateNode[NodeInputT, ContextT],
  248. *,
  249. defer: bool = False,
  250. metadata: dict[str, Any] | None = None,
  251. input_schema: None = None,
  252. retry_policy: RetryPolicy | Sequence[RetryPolicy] | None = None,
  253. cache_policy: CachePolicy | None = None,
  254. destinations: dict[str, str] | tuple[str, ...] | None = None,
  255. **kwargs: Unpack[DeprecatedKwargs],
  256. ) -> Self:
  257. """Add a new node to the `StateGraph`, input schema is inferred as the state schema.
  258. Will take the name of the function/runnable as the node name.
  259. """
  260. ...
  261. @overload
  262. def add_node(
  263. self,
  264. node: StateNode[NodeInputT, ContextT],
  265. *,
  266. defer: bool = False,
  267. metadata: dict[str, Any] | None = None,
  268. input_schema: type[NodeInputT],
  269. retry_policy: RetryPolicy | Sequence[RetryPolicy] | None = None,
  270. cache_policy: CachePolicy | None = None,
  271. destinations: dict[str, str] | tuple[str, ...] | None = None,
  272. **kwargs: Unpack[DeprecatedKwargs],
  273. ) -> Self:
  274. """Add a new node to the `StateGraph`, input schema is specified.
  275. Will take the name of the function/runnable as the node name.
  276. """
  277. ...
  278. @overload
  279. def add_node(
  280. self,
  281. node: str,
  282. action: StateNode[NodeInputT, ContextT],
  283. *,
  284. defer: bool = False,
  285. metadata: dict[str, Any] | None = None,
  286. input_schema: None = None,
  287. retry_policy: RetryPolicy | Sequence[RetryPolicy] | None = None,
  288. cache_policy: CachePolicy | None = None,
  289. destinations: dict[str, str] | tuple[str, ...] | None = None,
  290. **kwargs: Unpack[DeprecatedKwargs],
  291. ) -> Self:
  292. """Add a new node to the `StateGraph`, input schema is inferred as the state schema."""
  293. ...
  294. @overload
  295. def add_node(
  296. self,
  297. node: str | StateNode[NodeInputT, ContextT],
  298. action: StateNode[NodeInputT, ContextT] | None = None,
  299. *,
  300. defer: bool = False,
  301. metadata: dict[str, Any] | None = None,
  302. input_schema: type[NodeInputT],
  303. retry_policy: RetryPolicy | Sequence[RetryPolicy] | None = None,
  304. cache_policy: CachePolicy | None = None,
  305. destinations: dict[str, str] | tuple[str, ...] | None = None,
  306. **kwargs: Unpack[DeprecatedKwargs],
  307. ) -> Self:
  308. """Add a new node to the `StateGraph`, input schema is specified."""
  309. ...
  310. def add_node(
  311. self,
  312. node: str | StateNode[NodeInputT, ContextT],
  313. action: StateNode[NodeInputT, ContextT] | None = None,
  314. *,
  315. defer: bool = False,
  316. metadata: dict[str, Any] | None = None,
  317. input_schema: type[NodeInputT] | None = None,
  318. retry_policy: RetryPolicy | Sequence[RetryPolicy] | None = None,
  319. cache_policy: CachePolicy | None = None,
  320. destinations: dict[str, str] | tuple[str, ...] | None = None,
  321. **kwargs: Unpack[DeprecatedKwargs],
  322. ) -> Self:
  323. """Add a new node to the `StateGraph`.
  324. Args:
  325. node: The function or runnable this node will run.
  326. If a string is provided, it will be used as the node name, and action will be used as the function or runnable.
  327. action: The action associated with the node.
  328. Will be used as the node function or runnable if `node` is a string (node name).
  329. defer: Whether to defer the execution of the node until the run is about to end.
  330. metadata: The metadata associated with the node.
  331. input_schema: The input schema for the node. (Default: the graph's state schema)
  332. retry_policy: The retry policy for the node.
  333. If a sequence is provided, the first matching policy will be applied.
  334. cache_policy: The cache policy for the node.
  335. destinations: Destinations that indicate where a node can route to.
  336. Useful for edgeless graphs with nodes that return `Command` objects.
  337. If a `dict` is provided, the keys will be used as the target node names and the values will be used as the labels for the edges.
  338. If a `tuple` is provided, the values will be used as the target node names.
  339. !!! note
  340. This is only used for graph rendering and doesn't have any effect on the graph execution.
  341. Example:
  342. ```python
  343. from typing_extensions import TypedDict
  344. from langchain_core.runnables import RunnableConfig
  345. from langgraph.graph import START, StateGraph
  346. class State(TypedDict):
  347. x: int
  348. def my_node(state: State, config: RunnableConfig) -> State:
  349. return {"x": state["x"] + 1}
  350. builder = StateGraph(State)
  351. builder.add_node(my_node) # node name will be 'my_node'
  352. builder.add_edge(START, "my_node")
  353. graph = builder.compile()
  354. graph.invoke({"x": 1})
  355. # {'x': 2}
  356. ```
  357. Example: Customize the name:
  358. ```python
  359. builder = StateGraph(State)
  360. builder.add_node("my_fair_node", my_node)
  361. builder.add_edge(START, "my_fair_node")
  362. graph = builder.compile()
  363. graph.invoke({"x": 1})
  364. # {'x': 2}
  365. ```
  366. Returns:
  367. Self: The instance of the `StateGraph`, allowing for method chaining.
  368. """
  369. if (retry := kwargs.get("retry", MISSING)) is not MISSING:
  370. warnings.warn(
  371. "`retry` is deprecated and will be removed. Please use `retry_policy` instead.",
  372. category=LangGraphDeprecatedSinceV05,
  373. )
  374. if retry_policy is None:
  375. retry_policy = retry # type: ignore[assignment]
  376. if (input_ := kwargs.get("input", MISSING)) is not MISSING:
  377. warnings.warn(
  378. "`input` is deprecated and will be removed. Please use `input_schema` instead.",
  379. category=LangGraphDeprecatedSinceV05,
  380. )
  381. if input_schema is None:
  382. input_schema = cast(type[NodeInputT] | None, input_)
  383. if not isinstance(node, str):
  384. action = node
  385. if isinstance(action, Runnable):
  386. node = action.get_name()
  387. else:
  388. node = getattr(action, "__name__", action.__class__.__name__)
  389. if node is None:
  390. raise ValueError(
  391. "Node name must be provided if action is not a function"
  392. )
  393. if self.compiled:
  394. logger.warning(
  395. "Adding a node to a graph that has already been compiled. This will "
  396. "not be reflected in the compiled graph."
  397. )
  398. if not isinstance(node, str):
  399. action = node
  400. node = cast(str, getattr(action, "name", getattr(action, "__name__", None)))
  401. if node is None:
  402. raise ValueError(
  403. "Node name must be provided if action is not a function"
  404. )
  405. if action is None:
  406. raise RuntimeError
  407. if node in self.nodes:
  408. raise ValueError(f"Node `{node}` already present.")
  409. if node == END or node == START:
  410. raise ValueError(f"Node `{node}` is reserved.")
  411. for character in (NS_SEP, NS_END):
  412. if character in node:
  413. raise ValueError(
  414. f"'{character}' is a reserved character and is not allowed in the node names."
  415. )
  416. inferred_input_schema = None
  417. ends: tuple[str, ...] | dict[str, str] = EMPTY_SEQ
  418. try:
  419. if (
  420. isfunction(action)
  421. or ismethod(action)
  422. or ismethod(getattr(action, "__call__", None))
  423. ) and (
  424. hints := get_type_hints(getattr(action, "__call__"))
  425. or get_type_hints(action)
  426. ):
  427. if input_schema is None:
  428. first_parameter_name = next(
  429. iter(
  430. inspect.signature(
  431. cast(FunctionType, action)
  432. ).parameters.keys()
  433. )
  434. )
  435. if input_hint := hints.get(first_parameter_name):
  436. if isinstance(input_hint, type) and get_type_hints(input_hint):
  437. inferred_input_schema = input_hint
  438. if rtn := hints.get("return"):
  439. # Handle Union types
  440. rtn_origin = get_origin(rtn)
  441. if rtn_origin is Union:
  442. rtn_args = get_args(rtn)
  443. # Look for Command in the union
  444. for arg in rtn_args:
  445. arg_origin = get_origin(arg)
  446. if arg_origin is Command:
  447. rtn = arg
  448. rtn_origin = arg_origin
  449. break
  450. # Check if it's a Command type
  451. if (
  452. rtn_origin is Command
  453. and (rargs := get_args(rtn))
  454. and get_origin(rargs[0]) is Literal
  455. and (vals := get_args(rargs[0]))
  456. ):
  457. ends = vals
  458. except (NameError, TypeError, StopIteration):
  459. pass
  460. if destinations is not None:
  461. ends = destinations
  462. if input_schema is not None:
  463. self.nodes[node] = StateNodeSpec[NodeInputT, ContextT](
  464. coerce_to_runnable(action, name=node, trace=False), # type: ignore[arg-type]
  465. metadata,
  466. input_schema=input_schema,
  467. retry_policy=retry_policy,
  468. cache_policy=cache_policy,
  469. ends=ends,
  470. defer=defer,
  471. )
  472. elif inferred_input_schema is not None:
  473. self.nodes[node] = StateNodeSpec(
  474. coerce_to_runnable(action, name=node, trace=False), # type: ignore[arg-type]
  475. metadata,
  476. input_schema=inferred_input_schema,
  477. retry_policy=retry_policy,
  478. cache_policy=cache_policy,
  479. ends=ends,
  480. defer=defer,
  481. )
  482. else:
  483. self.nodes[node] = StateNodeSpec[StateT, ContextT](
  484. coerce_to_runnable(action, name=node, trace=False), # type: ignore[arg-type]
  485. metadata,
  486. input_schema=self.state_schema,
  487. retry_policy=retry_policy,
  488. cache_policy=cache_policy,
  489. ends=ends,
  490. defer=defer,
  491. )
  492. input_schema = input_schema or inferred_input_schema
  493. if input_schema is not None:
  494. self._add_schema(input_schema)
  495. return self
  496. def add_edge(self, start_key: str | list[str], end_key: str) -> Self:
  497. """Add a directed edge from the start node (or list of start nodes) to the end node.
  498. When a single start node is provided, the graph will wait for that node to complete
  499. before executing the end node. When multiple start nodes are provided,
  500. the graph will wait for ALL of the start nodes to complete before executing the end node.
  501. Args:
  502. start_key: The key(s) of the start node(s) of the edge.
  503. end_key: The key of the end node of the edge.
  504. Raises:
  505. ValueError: If the start key is `'END'` or if the start key or end key is not present in the graph.
  506. Returns:
  507. Self: The instance of the `StateGraph`, allowing for method chaining.
  508. """
  509. if self.compiled:
  510. logger.warning(
  511. "Adding an edge to a graph that has already been compiled. This will "
  512. "not be reflected in the compiled graph."
  513. )
  514. if isinstance(start_key, str):
  515. if start_key == END:
  516. raise ValueError("END cannot be a start node")
  517. if end_key == START:
  518. raise ValueError("START cannot be an end node")
  519. # run this validation only for non-StateGraph graphs
  520. if not hasattr(self, "channels") and start_key in set(
  521. start for start, _ in self.edges
  522. ):
  523. raise ValueError(
  524. f"Already found path for node '{start_key}'.\n"
  525. "For multiple edges, use StateGraph with an Annotated state key."
  526. )
  527. self.edges.add((start_key, end_key))
  528. return self
  529. for start in start_key:
  530. if start == END:
  531. raise ValueError("END cannot be a start node")
  532. if start not in self.nodes:
  533. raise ValueError(f"Need to add_node `{start}` first")
  534. if end_key == START:
  535. raise ValueError("START cannot be an end node")
  536. if end_key != END and end_key not in self.nodes:
  537. raise ValueError(f"Need to add_node `{end_key}` first")
  538. self.waiting_edges.add((tuple(start_key), end_key))
  539. return self
  540. def add_conditional_edges(
  541. self,
  542. source: str,
  543. path: Callable[..., Hashable | Sequence[Hashable]]
  544. | Callable[..., Awaitable[Hashable | Sequence[Hashable]]]
  545. | Runnable[Any, Hashable | Sequence[Hashable]],
  546. path_map: dict[Hashable, str] | list[str] | None = None,
  547. ) -> Self:
  548. """Add a conditional edge from the starting node to any number of destination nodes.
  549. Args:
  550. source: The starting node. This conditional edge will run when
  551. exiting this node.
  552. path: The callable that determines the next node or nodes.
  553. If not specifying `path_map` it should return one or more nodes.
  554. If it returns `'END'`, the graph will stop execution.
  555. path_map: Optional mapping of paths to node names.
  556. If omitted the paths returned by `path` should be node names.
  557. Returns:
  558. Self: The instance of the graph, allowing for method chaining.
  559. !!! warning
  560. Without type hints on the `path` function's return value (e.g., `-> Literal["foo", "__end__"]:`)
  561. or a path_map, the graph visualization assumes the edge could transition to any node in the graph.
  562. """ # noqa: E501
  563. if self.compiled:
  564. logger.warning(
  565. "Adding an edge to a graph that has already been compiled. This will "
  566. "not be reflected in the compiled graph."
  567. )
  568. # find a name for the condition
  569. path = coerce_to_runnable(path, name=None, trace=True)
  570. name = path.name or "condition"
  571. # validate the condition
  572. if name in self.branches[source]:
  573. raise ValueError(
  574. f"Branch with name `{path.name}` already exists for node `{source}`"
  575. )
  576. # save it
  577. self.branches[source][name] = BranchSpec.from_path(path, path_map, True)
  578. if schema := self.branches[source][name].input_schema:
  579. self._add_schema(schema)
  580. return self
  581. def add_sequence(
  582. self,
  583. nodes: Sequence[
  584. StateNode[NodeInputT, ContextT]
  585. | tuple[str, StateNode[NodeInputT, ContextT]]
  586. ],
  587. ) -> Self:
  588. """Add a sequence of nodes that will be executed in the provided order.
  589. Args:
  590. nodes: A sequence of `StateNode` (callables that accept a `state` arg) or `(name, StateNode)` tuples.
  591. If no names are provided, the name will be inferred from the node object (e.g. a `Runnable` or a `Callable` name).
  592. Each node will be executed in the order provided.
  593. Raises:
  594. ValueError: If the sequence is empty.
  595. ValueError: If the sequence contains duplicate node names.
  596. Returns:
  597. Self: The instance of the `StateGraph`, allowing for method chaining.
  598. """
  599. if len(nodes) < 1:
  600. raise ValueError("Sequence requires at least one node.")
  601. previous_name: str | None = None
  602. for node in nodes:
  603. if isinstance(node, tuple) and len(node) == 2:
  604. name, node = node
  605. else:
  606. name = _get_node_name(node)
  607. if name in self.nodes:
  608. raise ValueError(
  609. f"Node names must be unique: node with the name '{name}' already exists. "
  610. "If you need to use two different runnables/callables with the same name (for example, using `lambda`), please provide them as tuples (name, runnable/callable)."
  611. )
  612. self.add_node(name, node)
  613. if previous_name is not None:
  614. self.add_edge(previous_name, name)
  615. previous_name = name
  616. return self
  617. def set_entry_point(self, key: str) -> Self:
  618. """Specifies the first node to be called in the graph.
  619. Equivalent to calling `add_edge(START, key)`.
  620. Parameters:
  621. key (str): The key of the node to set as the entry point.
  622. Returns:
  623. Self: The instance of the graph, allowing for method chaining.
  624. """
  625. return self.add_edge(START, key)
  626. def set_conditional_entry_point(
  627. self,
  628. path: Callable[..., Hashable | Sequence[Hashable]]
  629. | Callable[..., Awaitable[Hashable | Sequence[Hashable]]]
  630. | Runnable[Any, Hashable | Sequence[Hashable]],
  631. path_map: dict[Hashable, str] | list[str] | None = None,
  632. ) -> Self:
  633. """Sets a conditional entry point in the graph.
  634. Args:
  635. path: The callable that determines the next node or nodes.
  636. If not specifying `path_map` it should return one or more nodes.
  637. If it returns END, the graph will stop execution.
  638. path_map: Optional mapping of paths to node names.
  639. If omitted the paths returned by `path` should be node names.
  640. Returns:
  641. Self: The instance of the graph, allowing for method chaining.
  642. """
  643. return self.add_conditional_edges(START, path, path_map)
  644. def set_finish_point(self, key: str) -> Self:
  645. """Marks a node as a finish point of the graph.
  646. If the graph reaches this node, it will cease execution.
  647. Parameters:
  648. key (str): The key of the node to set as the finish point.
  649. Returns:
  650. Self: The instance of the graph, allowing for method chaining.
  651. """
  652. return self.add_edge(key, END)
  653. def validate(self, interrupt: Sequence[str] | None = None) -> Self:
  654. # assemble sources
  655. all_sources = {src for src, _ in self._all_edges}
  656. for start, branches in self.branches.items():
  657. all_sources.add(start)
  658. for name, spec in self.nodes.items():
  659. if spec.ends:
  660. all_sources.add(name)
  661. # validate sources
  662. for source in all_sources:
  663. if source not in self.nodes and source != START:
  664. raise ValueError(f"Found edge starting at unknown node '{source}'")
  665. if START not in all_sources:
  666. raise ValueError(
  667. "Graph must have an entrypoint: add at least one edge from START to another node"
  668. )
  669. # assemble targets
  670. all_targets = {end for _, end in self._all_edges}
  671. for start, branches in self.branches.items():
  672. for cond, branch in branches.items():
  673. if branch.ends is not None:
  674. for end in branch.ends.values():
  675. if end not in self.nodes and end != END:
  676. raise ValueError(
  677. f"At '{start}' node, '{cond}' branch found unknown target '{end}'"
  678. )
  679. all_targets.add(end)
  680. else:
  681. all_targets.add(END)
  682. for node in self.nodes:
  683. if node != start:
  684. all_targets.add(node)
  685. for name, spec in self.nodes.items():
  686. if spec.ends:
  687. all_targets.update(spec.ends)
  688. for target in all_targets:
  689. if target not in self.nodes and target != END:
  690. raise ValueError(f"Found edge ending at unknown node `{target}`")
  691. # validate interrupts
  692. if interrupt:
  693. for node in interrupt:
  694. if node not in self.nodes:
  695. raise ValueError(f"Interrupt node `{node}` not found")
  696. self.compiled = True
  697. return self
  698. def compile(
  699. self,
  700. checkpointer: Checkpointer = None,
  701. *,
  702. cache: BaseCache | None = None,
  703. store: BaseStore | None = None,
  704. interrupt_before: All | list[str] | None = None,
  705. interrupt_after: All | list[str] | None = None,
  706. debug: bool = False,
  707. name: str | None = None,
  708. ) -> CompiledStateGraph[StateT, ContextT, InputT, OutputT]:
  709. """Compiles the `StateGraph` into a `CompiledStateGraph` object.
  710. The compiled graph implements the `Runnable` interface and can be invoked,
  711. streamed, batched, and run asynchronously.
  712. Args:
  713. checkpointer: A checkpoint saver object or flag.
  714. If provided, this `Checkpointer` serves as a fully versioned "short-term memory" for the graph,
  715. allowing it to be paused, resumed, and replayed from any point.
  716. If `None`, it may inherit the parent graph's checkpointer when used as a subgraph.
  717. If `False`, it will not use or inherit any checkpointer.
  718. interrupt_before: An optional list of node names to interrupt before.
  719. interrupt_after: An optional list of node names to interrupt after.
  720. debug: A flag indicating whether to enable debug mode.
  721. name: The name to use for the compiled graph.
  722. Returns:
  723. CompiledStateGraph: The compiled `StateGraph`.
  724. """
  725. # assign default values
  726. interrupt_before = interrupt_before or []
  727. interrupt_after = interrupt_after or []
  728. # validate the graph
  729. self.validate(
  730. interrupt=(
  731. (interrupt_before if interrupt_before != "*" else []) + interrupt_after
  732. if interrupt_after != "*"
  733. else []
  734. )
  735. )
  736. # prepare output channels
  737. output_channels = (
  738. "__root__"
  739. if len(self.schemas[self.output_schema]) == 1
  740. and "__root__" in self.schemas[self.output_schema]
  741. else [
  742. key
  743. for key, val in self.schemas[self.output_schema].items()
  744. if not is_managed_value(val)
  745. ]
  746. )
  747. stream_channels = (
  748. "__root__"
  749. if len(self.channels) == 1 and "__root__" in self.channels
  750. else [
  751. key for key, val in self.channels.items() if not is_managed_value(val)
  752. ]
  753. )
  754. compiled = CompiledStateGraph[StateT, ContextT, InputT, OutputT](
  755. builder=self,
  756. schema_to_mapper={},
  757. context_schema=self.context_schema,
  758. nodes={},
  759. channels={
  760. **self.channels,
  761. **self.managed,
  762. START: EphemeralValue(self.input_schema),
  763. },
  764. input_channels=START,
  765. stream_mode="updates",
  766. output_channels=output_channels,
  767. stream_channels=stream_channels,
  768. checkpointer=checkpointer,
  769. interrupt_before_nodes=interrupt_before,
  770. interrupt_after_nodes=interrupt_after,
  771. auto_validate=False,
  772. debug=debug,
  773. store=store,
  774. cache=cache,
  775. name=name or "LangGraph",
  776. )
  777. compiled.attach_node(START, None)
  778. for key, node in self.nodes.items():
  779. compiled.attach_node(key, node)
  780. for start, end in self.edges:
  781. compiled.attach_edge(start, end)
  782. for starts, end in self.waiting_edges:
  783. compiled.attach_edge(starts, end)
  784. for start, branches in self.branches.items():
  785. for name, branch in branches.items():
  786. compiled.attach_branch(start, name, branch)
  787. return compiled.validate()
  788. class CompiledStateGraph(
  789. Pregel[StateT, ContextT, InputT, OutputT],
  790. Generic[StateT, ContextT, InputT, OutputT],
  791. ):
  792. builder: StateGraph[StateT, ContextT, InputT, OutputT]
  793. schema_to_mapper: dict[type[Any], Callable[[Any], Any] | None]
  794. def __init__(
  795. self,
  796. *,
  797. builder: StateGraph[StateT, ContextT, InputT, OutputT],
  798. schema_to_mapper: dict[type[Any], Callable[[Any], Any] | None],
  799. **kwargs: Any,
  800. ) -> None:
  801. super().__init__(**kwargs)
  802. self.builder = builder
  803. self.schema_to_mapper = schema_to_mapper
  804. def get_input_jsonschema(
  805. self, config: RunnableConfig | None = None
  806. ) -> dict[str, Any]:
  807. return _get_json_schema(
  808. typ=self.builder.input_schema,
  809. schemas=self.builder.schemas,
  810. channels=self.builder.channels,
  811. name=self.get_name("Input"),
  812. )
  813. def get_output_jsonschema(
  814. self, config: RunnableConfig | None = None
  815. ) -> dict[str, Any]:
  816. return _get_json_schema(
  817. typ=self.builder.output_schema,
  818. schemas=self.builder.schemas,
  819. channels=self.builder.channels,
  820. name=self.get_name("Output"),
  821. )
  822. def attach_node(self, key: str, node: StateNodeSpec[Any, ContextT] | None) -> None:
  823. if key == START:
  824. output_keys = [
  825. k
  826. for k, v in self.builder.schemas[self.builder.input_schema].items()
  827. if not is_managed_value(v)
  828. ]
  829. else:
  830. output_keys = list(self.builder.channels) + [
  831. k for k, v in self.builder.managed.items()
  832. ]
  833. def _get_updates(
  834. input: None | dict | Any,
  835. ) -> Sequence[tuple[str, Any]] | None:
  836. if input is None:
  837. return None
  838. elif isinstance(input, dict):
  839. return [(k, v) for k, v in input.items() if k in output_keys]
  840. elif isinstance(input, Command):
  841. if input.graph == Command.PARENT:
  842. return None
  843. return [
  844. (k, v) for k, v in input._update_as_tuples() if k in output_keys
  845. ]
  846. elif (
  847. isinstance(input, (list, tuple))
  848. and input
  849. and any(isinstance(i, Command) for i in input)
  850. ):
  851. updates: list[tuple[str, Any]] = []
  852. for i in input:
  853. if isinstance(i, Command):
  854. if i.graph == Command.PARENT:
  855. continue
  856. updates.extend(
  857. (k, v) for k, v in i._update_as_tuples() if k in output_keys
  858. )
  859. else:
  860. updates.extend(_get_updates(i) or ())
  861. return updates
  862. elif (t := type(input)) and get_cached_annotated_keys(t):
  863. return get_update_as_tuples(input, output_keys)
  864. else:
  865. msg = create_error_message(
  866. message=f"Expected dict, got {input}",
  867. error_code=ErrorCode.INVALID_GRAPH_NODE_RETURN_VALUE,
  868. )
  869. raise InvalidUpdateError(msg)
  870. # state updaters
  871. write_entries: tuple[ChannelWriteEntry | ChannelWriteTupleEntry, ...] = (
  872. ChannelWriteTupleEntry(
  873. mapper=_get_root if output_keys == ["__root__"] else _get_updates
  874. ),
  875. ChannelWriteTupleEntry(
  876. mapper=_control_branch,
  877. static=_control_static(node.ends)
  878. if node is not None and node.ends is not None
  879. else None,
  880. ),
  881. )
  882. # add node and output channel
  883. if key == START:
  884. self.nodes[key] = PregelNode(
  885. tags=[TAG_HIDDEN],
  886. triggers=[START],
  887. channels=START,
  888. writers=[ChannelWrite(write_entries)],
  889. )
  890. elif node is not None:
  891. input_schema = node.input_schema if node else self.builder.state_schema
  892. input_channels = list(self.builder.schemas[input_schema])
  893. is_single_input = len(input_channels) == 1 and "__root__" in input_channels
  894. if input_schema in self.schema_to_mapper:
  895. mapper = self.schema_to_mapper[input_schema]
  896. else:
  897. mapper = _pick_mapper(input_channels, input_schema)
  898. self.schema_to_mapper[input_schema] = mapper
  899. branch_channel = _CHANNEL_BRANCH_TO.format(key)
  900. self.channels[branch_channel] = (
  901. LastValueAfterFinish(Any)
  902. if node.defer
  903. else EphemeralValue(Any, guard=False)
  904. )
  905. self.nodes[key] = PregelNode(
  906. triggers=[branch_channel],
  907. # read state keys and managed values
  908. channels=("__root__" if is_single_input else input_channels),
  909. # coerce state dict to schema class (eg. pydantic model)
  910. mapper=mapper,
  911. # publish to state keys
  912. writers=[ChannelWrite(write_entries)],
  913. metadata=node.metadata,
  914. retry_policy=node.retry_policy,
  915. cache_policy=node.cache_policy,
  916. bound=node.runnable, # type: ignore[arg-type]
  917. )
  918. else:
  919. raise RuntimeError
  920. def attach_edge(self, starts: str | Sequence[str], end: str) -> None:
  921. if isinstance(starts, str):
  922. # subscribe to start channel
  923. if end != END:
  924. self.nodes[starts].writers.append(
  925. ChannelWrite(
  926. (ChannelWriteEntry(_CHANNEL_BRANCH_TO.format(end), None),)
  927. )
  928. )
  929. elif end != END:
  930. channel_name = f"join:{'+'.join(starts)}:{end}"
  931. # register channel
  932. if self.builder.nodes[end].defer:
  933. self.channels[channel_name] = NamedBarrierValueAfterFinish(
  934. str, set(starts)
  935. )
  936. else:
  937. self.channels[channel_name] = NamedBarrierValue(str, set(starts))
  938. # subscribe to channel
  939. self.nodes[end].triggers.append(channel_name)
  940. # publish to channel
  941. for start in starts:
  942. self.nodes[start].writers.append(
  943. ChannelWrite((ChannelWriteEntry(channel_name, start),))
  944. )
  945. def attach_branch(
  946. self, start: str, name: str, branch: BranchSpec, *, with_reader: bool = True
  947. ) -> None:
  948. def get_writes(
  949. packets: Sequence[str | Send], static: bool = False
  950. ) -> Sequence[ChannelWriteEntry | Send]:
  951. writes = [
  952. (
  953. ChannelWriteEntry(
  954. p if p == END else _CHANNEL_BRANCH_TO.format(p), None
  955. )
  956. if not isinstance(p, Send)
  957. else p
  958. )
  959. for p in packets
  960. if (True if static else p != END)
  961. ]
  962. if not writes:
  963. return []
  964. return writes
  965. if with_reader:
  966. # get schema
  967. schema = branch.input_schema or (
  968. self.builder.nodes[start].input_schema
  969. if start in self.builder.nodes
  970. else self.builder.state_schema
  971. )
  972. channels = list(self.builder.schemas[schema])
  973. # get mapper
  974. if schema in self.schema_to_mapper:
  975. mapper = self.schema_to_mapper[schema]
  976. else:
  977. mapper = _pick_mapper(channels, schema)
  978. self.schema_to_mapper[schema] = mapper
  979. # create reader
  980. reader: Callable[[RunnableConfig], Any] | None = partial(
  981. ChannelRead.do_read,
  982. select=channels[0] if channels == ["__root__"] else channels,
  983. fresh=True,
  984. # coerce state dict to schema class (eg. pydantic model)
  985. mapper=mapper,
  986. )
  987. else:
  988. reader = None
  989. # attach branch publisher
  990. self.nodes[start].writers.append(branch.run(get_writes, reader))
  991. def _migrate_checkpoint(self, checkpoint: Checkpoint) -> None:
  992. """Migrate a checkpoint to new channel layout."""
  993. super()._migrate_checkpoint(checkpoint)
  994. values = checkpoint["channel_values"]
  995. versions = checkpoint["channel_versions"]
  996. seen = checkpoint["versions_seen"]
  997. # empty checkpoints do not need migration
  998. if not versions:
  999. return
  1000. # current version
  1001. if checkpoint["v"] >= 3:
  1002. return
  1003. # Migrate from start:node to branch:to:node
  1004. for k in list(versions):
  1005. if k.startswith("start:"):
  1006. # confirm node is present
  1007. node = k.split(":")[1]
  1008. if node not in self.nodes:
  1009. continue
  1010. # get next version
  1011. new_k = f"branch:to:{node}"
  1012. new_v = (
  1013. max(versions[new_k], versions.pop(k))
  1014. if new_k in versions
  1015. else versions.pop(k)
  1016. )
  1017. # update seen
  1018. for ss in (seen.get(node, {}), seen.get(INTERRUPT, {})):
  1019. if k in ss:
  1020. s = ss.pop(k)
  1021. if new_k in ss:
  1022. ss[new_k] = max(s, ss[new_k])
  1023. else:
  1024. ss[new_k] = s
  1025. # update value
  1026. if new_k not in values and k in values:
  1027. values[new_k] = values.pop(k)
  1028. # update version
  1029. versions[new_k] = new_v
  1030. # Migrate from branch:source:condition:node to branch:to:node
  1031. for k in list(versions):
  1032. if k.startswith("branch:") and k.count(":") == 3:
  1033. # confirm node is present
  1034. node = k.split(":")[-1]
  1035. if node not in self.nodes:
  1036. continue
  1037. # get next version
  1038. new_k = f"branch:to:{node}"
  1039. new_v = (
  1040. max(versions[new_k], versions.pop(k))
  1041. if new_k in versions
  1042. else versions.pop(k)
  1043. )
  1044. # update seen
  1045. for ss in (seen.get(node, {}), seen.get(INTERRUPT, {})):
  1046. if k in ss:
  1047. s = ss.pop(k)
  1048. if new_k in ss:
  1049. ss[new_k] = max(s, ss[new_k])
  1050. else:
  1051. ss[new_k] = s
  1052. # update value
  1053. if new_k not in values and k in values:
  1054. values[new_k] = values.pop(k)
  1055. # update version
  1056. versions[new_k] = new_v
  1057. if not set(self.nodes).isdisjoint(versions):
  1058. # Migrate from "node" to "branch:to:node"
  1059. source_to_target = defaultdict(list)
  1060. for start, end in self.builder.edges:
  1061. if start != START and end != END:
  1062. source_to_target[start].append(end)
  1063. for k in list(versions):
  1064. if k == START:
  1065. continue
  1066. if k in self.nodes:
  1067. v = versions.pop(k)
  1068. c = values.pop(k, MISSING)
  1069. for end in source_to_target[k]:
  1070. # get next version
  1071. new_k = f"branch:to:{end}"
  1072. new_v = max(versions[new_k], v) if new_k in versions else v
  1073. # update seen
  1074. for ss in (seen.get(end, {}), seen.get(INTERRUPT, {})):
  1075. if k in ss:
  1076. s = ss.pop(k)
  1077. if new_k in ss:
  1078. ss[new_k] = max(s, ss[new_k])
  1079. else:
  1080. ss[new_k] = s
  1081. # update value
  1082. if new_k not in values and c is not MISSING:
  1083. values[new_k] = c
  1084. # update version
  1085. versions[new_k] = new_v
  1086. # pop interrupt seen
  1087. if INTERRUPT in seen:
  1088. seen[INTERRUPT].pop(k, MISSING)
  1089. def _pick_mapper(
  1090. state_keys: Sequence[str], schema: type[Any]
  1091. ) -> Callable[[Any], Any] | None:
  1092. if state_keys == ["__root__"]:
  1093. return None
  1094. if isclass(schema) and issubclass(schema, dict):
  1095. return None
  1096. return partial(_coerce_state, schema)
  1097. def _coerce_state(schema: type[Any], input: dict[str, Any]) -> dict[str, Any]:
  1098. return schema(**input)
  1099. def _control_branch(value: Any) -> Sequence[tuple[str, Any]]:
  1100. if isinstance(value, Send):
  1101. return ((TASKS, value),)
  1102. commands: list[Command] = []
  1103. if isinstance(value, Command):
  1104. commands.append(value)
  1105. elif isinstance(value, (list, tuple)):
  1106. for cmd in value:
  1107. if isinstance(cmd, Command):
  1108. commands.append(cmd)
  1109. rtn: list[tuple[str, Any]] = []
  1110. for command in commands:
  1111. if command.graph == Command.PARENT:
  1112. raise ParentCommand(command)
  1113. goto_targets = (
  1114. [command.goto] if isinstance(command.goto, (Send, str)) else command.goto
  1115. )
  1116. for go in goto_targets:
  1117. if isinstance(go, Send):
  1118. rtn.append((TASKS, go))
  1119. elif isinstance(go, str) and go != END:
  1120. # END is a special case, it's not actually a node in a practical sense
  1121. # but rather a special terminal node that we don't need to branch to
  1122. rtn.append((_CHANNEL_BRANCH_TO.format(go), None))
  1123. return rtn
  1124. def _control_static(
  1125. ends: tuple[str, ...] | dict[str, str],
  1126. ) -> Sequence[tuple[str, Any, str | None]]:
  1127. if isinstance(ends, dict):
  1128. return [
  1129. (k if k == END else _CHANNEL_BRANCH_TO.format(k), None, label)
  1130. for k, label in ends.items()
  1131. ]
  1132. else:
  1133. return [
  1134. (e if e == END else _CHANNEL_BRANCH_TO.format(e), None, None) for e in ends
  1135. ]
  1136. def _get_root(input: Any) -> Sequence[tuple[str, Any]] | None:
  1137. if isinstance(input, Command):
  1138. if input.graph == Command.PARENT:
  1139. return ()
  1140. return input._update_as_tuples()
  1141. elif (
  1142. isinstance(input, (list, tuple))
  1143. and input
  1144. and any(isinstance(i, Command) for i in input)
  1145. ):
  1146. updates: list[tuple[str, Any]] = []
  1147. for i in input:
  1148. if isinstance(i, Command):
  1149. if i.graph == Command.PARENT:
  1150. continue
  1151. updates.extend(i._update_as_tuples())
  1152. else:
  1153. updates.append(("__root__", i))
  1154. return updates
  1155. elif input is not None:
  1156. return [("__root__", input)]
  1157. def _get_channels(
  1158. schema: type[dict],
  1159. ) -> tuple[dict[str, BaseChannel], dict[str, ManagedValueSpec], dict[str, Any]]:
  1160. if not hasattr(schema, "__annotations__"):
  1161. return (
  1162. {"__root__": _get_channel("__root__", schema, allow_managed=False)},
  1163. {},
  1164. {},
  1165. )
  1166. type_hints = get_type_hints(schema, include_extras=True)
  1167. all_keys = {
  1168. name: _get_channel(name, typ)
  1169. for name, typ in type_hints.items()
  1170. if name != "__slots__"
  1171. }
  1172. return (
  1173. {k: v for k, v in all_keys.items() if isinstance(v, BaseChannel)},
  1174. {k: v for k, v in all_keys.items() if is_managed_value(v)},
  1175. type_hints,
  1176. )
  1177. @overload
  1178. def _get_channel(
  1179. name: str, annotation: Any, *, allow_managed: Literal[False]
  1180. ) -> BaseChannel: ...
  1181. @overload
  1182. def _get_channel(
  1183. name: str, annotation: Any, *, allow_managed: Literal[True] = True
  1184. ) -> BaseChannel | ManagedValueSpec: ...
  1185. def _get_channel(
  1186. name: str, annotation: Any, *, allow_managed: bool = True
  1187. ) -> BaseChannel | ManagedValueSpec:
  1188. # Strip out Required and NotRequired wrappers
  1189. if hasattr(annotation, "__origin__") and annotation.__origin__ in (
  1190. Required,
  1191. NotRequired,
  1192. ):
  1193. annotation = annotation.__args__[0]
  1194. if manager := _is_field_managed_value(name, annotation):
  1195. if allow_managed:
  1196. return manager
  1197. else:
  1198. raise ValueError(f"This {annotation} not allowed in this position")
  1199. elif channel := _is_field_channel(annotation):
  1200. channel.key = name
  1201. return channel
  1202. elif channel := _is_field_binop(annotation):
  1203. channel.key = name
  1204. return channel
  1205. fallback: LastValue = LastValue(annotation)
  1206. fallback.key = name
  1207. return fallback
  1208. def _is_field_channel(typ: type[Any]) -> BaseChannel | None:
  1209. if hasattr(typ, "__metadata__"):
  1210. meta = typ.__metadata__
  1211. # Search through all annotated medata to find channel annotations
  1212. for item in meta:
  1213. if isinstance(item, BaseChannel):
  1214. return item
  1215. elif isclass(item) and issubclass(item, BaseChannel):
  1216. # ex, Annotated[int, EphemeralValue, SomeOtherAnnotation]
  1217. # would return EphemeralValue(int)
  1218. return item(typ.__origin__ if hasattr(typ, "__origin__") else typ)
  1219. return None
  1220. def _is_field_binop(typ: type[Any]) -> BinaryOperatorAggregate | None:
  1221. if hasattr(typ, "__metadata__"):
  1222. meta = typ.__metadata__
  1223. if len(meta) >= 1 and callable(meta[-1]):
  1224. sig = signature(meta[-1])
  1225. params = list(sig.parameters.values())
  1226. if (
  1227. sum(
  1228. p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
  1229. for p in params
  1230. )
  1231. == 2
  1232. ):
  1233. return BinaryOperatorAggregate(typ, meta[-1])
  1234. else:
  1235. raise ValueError(
  1236. f"Invalid reducer signature. Expected (a, b) -> c. Got {sig}"
  1237. )
  1238. return None
  1239. def _is_field_managed_value(name: str, typ: type[Any]) -> ManagedValueSpec | None:
  1240. if hasattr(typ, "__metadata__"):
  1241. meta = typ.__metadata__
  1242. if len(meta) >= 1:
  1243. decoration = get_origin(meta[-1]) or meta[-1]
  1244. if is_managed_value(decoration):
  1245. return decoration
  1246. # Handle Required, NotRequired, etc wrapped types by extracting the inner type
  1247. if (
  1248. get_origin(typ) is not None
  1249. and (args := get_args(typ))
  1250. and (inner_type := args[0])
  1251. ):
  1252. return _is_field_managed_value(name, inner_type)
  1253. return None
  1254. def _get_json_schema(
  1255. typ: type,
  1256. schemas: dict,
  1257. channels: dict,
  1258. name: str,
  1259. ) -> dict[str, Any]:
  1260. if isclass(typ) and issubclass(typ, BaseModel):
  1261. return typ.model_json_schema()
  1262. elif is_typeddict(typ):
  1263. return TypeAdapter(typ).json_schema()
  1264. else:
  1265. keys = list(schemas[typ].keys())
  1266. if len(keys) == 1 and keys[0] == "__root__":
  1267. return create_model(
  1268. name,
  1269. root=(channels[keys[0]].UpdateType, None),
  1270. ).model_json_schema()
  1271. else:
  1272. return create_model(
  1273. name,
  1274. field_definitions={
  1275. k: (
  1276. channels[k].UpdateType,
  1277. (
  1278. get_field_default(
  1279. k,
  1280. channels[k].UpdateType,
  1281. typ,
  1282. )
  1283. ),
  1284. )
  1285. for k in schemas[typ]
  1286. if k in channels and isinstance(channels[k], BaseChannel)
  1287. },
  1288. ).model_json_schema()