# main.py
  1. from __future__ import annotations
  2. import asyncio
  3. import concurrent
  4. import concurrent.futures
  5. import contextlib
  6. import queue
  7. import warnings
  8. import weakref
  9. from collections import defaultdict, deque
  10. from collections.abc import (
  11. AsyncIterator,
  12. Awaitable,
  13. Callable,
  14. Iterator,
  15. Mapping,
  16. Sequence,
  17. )
  18. from dataclasses import is_dataclass
  19. from functools import partial
  20. from inspect import isclass
  21. from typing import (
  22. Any,
  23. Generic,
  24. cast,
  25. get_type_hints,
  26. )
  27. from uuid import UUID, uuid5
  28. from langchain_core.globals import get_debug
  29. from langchain_core.runnables import (
  30. RunnableSequence,
  31. )
  32. from langchain_core.runnables.base import Input, Output
  33. from langchain_core.runnables.config import (
  34. RunnableConfig,
  35. get_async_callback_manager_for_config,
  36. get_callback_manager_for_config,
  37. )
  38. from langchain_core.runnables.graph import Graph
  39. from langgraph.cache.base import BaseCache
  40. from langgraph.checkpoint.base import (
  41. BaseCheckpointSaver,
  42. Checkpoint,
  43. CheckpointTuple,
  44. )
  45. from langgraph.store.base import BaseStore
  46. from pydantic import BaseModel, TypeAdapter
  47. from typing_extensions import Self, Unpack, deprecated, is_typeddict
  48. from langgraph._internal._config import (
  49. ensure_config,
  50. merge_configs,
  51. patch_checkpoint_map,
  52. patch_config,
  53. patch_configurable,
  54. recast_checkpoint_ns,
  55. )
  56. from langgraph._internal._constants import (
  57. CACHE_NS_WRITES,
  58. CONF,
  59. CONFIG_KEY_CACHE,
  60. CONFIG_KEY_CHECKPOINT_ID,
  61. CONFIG_KEY_CHECKPOINT_NS,
  62. CONFIG_KEY_CHECKPOINTER,
  63. CONFIG_KEY_DURABILITY,
  64. CONFIG_KEY_NODE_FINISHED,
  65. CONFIG_KEY_READ,
  66. CONFIG_KEY_RUNNER_SUBMIT,
  67. CONFIG_KEY_RUNTIME,
  68. CONFIG_KEY_SEND,
  69. CONFIG_KEY_STREAM,
  70. CONFIG_KEY_TASK_ID,
  71. CONFIG_KEY_THREAD_ID,
  72. ERROR,
  73. INPUT,
  74. INTERRUPT,
  75. NS_END,
  76. NS_SEP,
  77. NULL_TASK_ID,
  78. PUSH,
  79. TASKS,
  80. )
  81. from langgraph._internal._pydantic import create_model
  82. from langgraph._internal._queue import ( # type: ignore[attr-defined]
  83. AsyncQueue,
  84. SyncQueue,
  85. )
  86. from langgraph._internal._runnable import (
  87. Runnable,
  88. RunnableLike,
  89. RunnableSeq,
  90. coerce_to_runnable,
  91. )
  92. from langgraph._internal._typing import MISSING, DeprecatedKwargs
  93. from langgraph.channels.base import BaseChannel
  94. from langgraph.channels.topic import Topic
  95. from langgraph.config import get_config
  96. from langgraph.constants import END
  97. from langgraph.errors import (
  98. ErrorCode,
  99. GraphRecursionError,
  100. InvalidUpdateError,
  101. create_error_message,
  102. )
  103. from langgraph.managed.base import ManagedValueSpec
  104. from langgraph.pregel._algo import (
  105. PregelTaskWrites,
  106. _scratchpad,
  107. apply_writes,
  108. local_read,
  109. prepare_next_tasks,
  110. )
  111. from langgraph.pregel._call import identifier
  112. from langgraph.pregel._checkpoint import (
  113. channels_from_checkpoint,
  114. copy_checkpoint,
  115. create_checkpoint,
  116. empty_checkpoint,
  117. )
  118. from langgraph.pregel._draw import draw_graph
  119. from langgraph.pregel._io import map_input, read_channels
  120. from langgraph.pregel._loop import AsyncPregelLoop, SyncPregelLoop
  121. from langgraph.pregel._messages import StreamMessagesHandler
  122. from langgraph.pregel._read import DEFAULT_BOUND, PregelNode
  123. from langgraph.pregel._retry import RetryPolicy
  124. from langgraph.pregel._runner import PregelRunner
  125. from langgraph.pregel._utils import get_new_channel_versions
  126. from langgraph.pregel._validate import validate_graph, validate_keys
  127. from langgraph.pregel._write import ChannelWrite, ChannelWriteEntry
  128. from langgraph.pregel.debug import get_bolded_text, get_colored_text, tasks_w_writes
  129. from langgraph.pregel.protocol import PregelProtocol, StreamChunk, StreamProtocol
  130. from langgraph.runtime import DEFAULT_RUNTIME, Runtime
  131. from langgraph.types import (
  132. All,
  133. CachePolicy,
  134. Checkpointer,
  135. Command,
  136. Durability,
  137. Interrupt,
  138. Send,
  139. StateSnapshot,
  140. StateUpdate,
  141. StreamMode,
  142. )
  143. from langgraph.typing import ContextT, InputT, OutputT, StateT
  144. from langgraph.warnings import LangGraphDeprecatedSinceV10
  145. try:
  146. from langchain_core.tracers._streaming import _StreamingCallbackHandler
  147. except ImportError:
  148. _StreamingCallbackHandler = None # type: ignore
  149. __all__ = ("NodeBuilder", "Pregel")
  150. _WriteValue = Callable[[Input], Output] | Any
  151. class NodeBuilder:
  152. __slots__ = (
  153. "_channels",
  154. "_triggers",
  155. "_tags",
  156. "_metadata",
  157. "_writes",
  158. "_bound",
  159. "_retry_policy",
  160. "_cache_policy",
  161. )
  162. _channels: str | list[str]
  163. _triggers: list[str]
  164. _tags: list[str]
  165. _metadata: dict[str, Any]
  166. _writes: list[ChannelWriteEntry]
  167. _bound: Runnable
  168. _retry_policy: list[RetryPolicy]
  169. _cache_policy: CachePolicy | None
  170. def __init__(
  171. self,
  172. ) -> None:
  173. self._channels = []
  174. self._triggers = []
  175. self._tags = []
  176. self._metadata = {}
  177. self._writes = []
  178. self._bound = DEFAULT_BOUND
  179. self._retry_policy = []
  180. self._cache_policy = None
  181. def subscribe_only(
  182. self,
  183. channel: str,
  184. ) -> Self:
  185. """Subscribe to a single channel."""
  186. if not self._channels:
  187. self._channels = channel
  188. else:
  189. raise ValueError(
  190. "Cannot subscribe to single channels when other channels are already subscribed to"
  191. )
  192. self._triggers.append(channel)
  193. return self
  194. def subscribe_to(
  195. self,
  196. *channels: str,
  197. read: bool = True,
  198. ) -> Self:
  199. """Add channels to subscribe to.
  200. Node will be invoked when any of these channels are updated, with a dict of the
  201. channel values as input.
  202. Args:
  203. channels: Channel name(s) to subscribe to
  204. read: If `True`, the channels will be included in the input to the node.
  205. Otherwise, they will trigger the node without being sent in input.
  206. Returns:
  207. Self for chaining
  208. """
  209. if isinstance(self._channels, str):
  210. raise ValueError(
  211. "Cannot subscribe to channels when subscribed to a single channel"
  212. )
  213. if read:
  214. if not self._channels:
  215. self._channels = list(channels)
  216. else:
  217. self._channels.extend(channels)
  218. if isinstance(channels, str):
  219. self._triggers.append(channels)
  220. else:
  221. self._triggers.extend(channels)
  222. return self
  223. def read_from(
  224. self,
  225. *channels: str,
  226. ) -> Self:
  227. """Adds the specified channels to read from, without subscribing to them."""
  228. assert isinstance(self._channels, list), (
  229. "Cannot read additional channels when subscribed to single channels"
  230. )
  231. self._channels.extend(channels)
  232. return self
  233. def do(
  234. self,
  235. node: RunnableLike,
  236. ) -> Self:
  237. """Adds the specified node."""
  238. if self._bound is not DEFAULT_BOUND:
  239. self._bound = RunnableSeq(
  240. self._bound, coerce_to_runnable(node, name=None, trace=True)
  241. )
  242. else:
  243. self._bound = coerce_to_runnable(node, name=None, trace=True)
  244. return self
  245. def write_to(
  246. self,
  247. *channels: str | ChannelWriteEntry,
  248. **kwargs: _WriteValue,
  249. ) -> Self:
  250. """Add channel writes.
  251. Args:
  252. *channels: Channel names to write to.
  253. **kwargs: Channel name and value mappings.
  254. Returns:
  255. Self for chaining
  256. """
  257. self._writes.extend(
  258. ChannelWriteEntry(c) if isinstance(c, str) else c for c in channels
  259. )
  260. self._writes.extend(
  261. ChannelWriteEntry(k, mapper=v)
  262. if callable(v)
  263. else ChannelWriteEntry(k, value=v)
  264. for k, v in kwargs.items()
  265. )
  266. return self
  267. def meta(self, *tags: str, **metadata: Any) -> Self:
  268. """Add tags or metadata to the node."""
  269. self._tags.extend(tags)
  270. self._metadata.update(metadata)
  271. return self
  272. def add_retry_policies(self, *policies: RetryPolicy) -> Self:
  273. """Adds retry policies to the node."""
  274. self._retry_policy.extend(policies)
  275. return self
  276. def add_cache_policy(self, policy: CachePolicy) -> Self:
  277. """Adds cache policies to the node."""
  278. self._cache_policy = policy
  279. return self
  280. def build(self) -> PregelNode:
  281. """Builds the node."""
  282. return PregelNode(
  283. channels=self._channels,
  284. triggers=self._triggers,
  285. tags=self._tags,
  286. metadata=self._metadata,
  287. writers=[ChannelWrite(self._writes)],
  288. bound=self._bound,
  289. retry_policy=self._retry_policy,
  290. cache_policy=self._cache_policy,
  291. )
  292. class Pregel(
  293. PregelProtocol[StateT, ContextT, InputT, OutputT],
  294. Generic[StateT, ContextT, InputT, OutputT],
  295. ):
  296. """Pregel manages the runtime behavior for LangGraph applications.
  297. ## Overview
  298. Pregel combines [**actors**](https://en.wikipedia.org/wiki/Actor_model)
  299. and **channels** into a single application.
  300. **Actors** read data from channels and write data to channels.
  301. Pregel organizes the execution of the application into multiple steps,
  302. following the **Pregel Algorithm**/**Bulk Synchronous Parallel** model.
  303. Each step consists of three phases:
  304. - **Plan**: Determine which **actors** to execute in this step. For example,
  305. in the first step, select the **actors** that subscribe to the special
  306. **input** channels; in subsequent steps,
  307. select the **actors** that subscribe to channels updated in the previous step.
  308. - **Execution**: Execute all selected **actors** in parallel,
  309. until all complete, or one fails, or a timeout is reached. During this
  310. phase, channel updates are invisible to actors until the next step.
  311. - **Update**: Update the channels with the values written by the **actors**
  312. in this step.
  313. Repeat until no **actors** are selected for execution, or a maximum number of
  314. steps is reached.
  315. ## Actors
  316. An **actor** is a `PregelNode`.
  317. It subscribes to channels, reads data from them, and writes data to them.
  318. It can be thought of as an **actor** in the Pregel algorithm.
  319. `PregelNodes` implement LangChain's
  320. Runnable interface.
  321. ## Channels
  322. Channels are used to communicate between actors (`PregelNodes`).
  323. Each channel has a value type, an update type, and an update function – which
  324. takes a sequence of updates and
  325. modifies the stored value. Channels can be used to send data from one chain to
  326. another, or to send data from a chain to itself in a future step. LangGraph
  327. provides a number of built-in channels:
  328. ### Basic channels: LastValue and Topic
  329. - `LastValue`: The default channel, stores the last value sent to the channel,
  330. useful for input and output values, or for sending data from one step to the next
  331. - `Topic`: A configurable PubSub Topic, useful for sending multiple values
  332. between *actors*, or for accumulating output. Can be configured to deduplicate
  333. values, and/or to accumulate values over the course of multiple steps.
  334. ### Advanced channels: Context and BinaryOperatorAggregate
  335. - `Context`: exposes the value of a context manager, managing its lifecycle.
  336. Useful for accessing external resources that require setup and/or teardown. e.g.
  337. `client = Context(httpx.Client)`
  338. - `BinaryOperatorAggregate`: stores a persistent value, updated by applying
  339. a binary operator to the current value and each update
  340. sent to the channel, useful for computing aggregates over multiple steps. e.g.
  341. `total = BinaryOperatorAggregate(int, operator.add)`
  342. ## Examples
  343. Most users will interact with Pregel via a
  344. [StateGraph (Graph API)][langgraph.graph.StateGraph] or via an
  345. [entrypoint (Functional API)][langgraph.func.entrypoint].
  346. However, for **advanced** use cases, Pregel can be used directly. If you're
  347. not sure whether you need to use Pregel directly, then the answer is probably no
  348. - you should use the Graph API or Functional API instead. These are higher-level
  349. interfaces that will compile down to Pregel under the hood.
  350. Here are some examples to give you a sense of how it works:
  351. Example: Single node application
  352. ```python
  353. from langgraph.channels import EphemeralValue
  354. from langgraph.pregel import Pregel, NodeBuilder
  355. node1 = (
  356. NodeBuilder().subscribe_only("a")
  357. .do(lambda x: x + x)
  358. .write_to("b")
  359. )
  360. app = Pregel(
  361. nodes={"node1": node1},
  362. channels={
  363. "a": EphemeralValue(str),
  364. "b": EphemeralValue(str),
  365. },
  366. input_channels=["a"],
  367. output_channels=["b"],
  368. )
  369. app.invoke({"a": "foo"})
  370. ```
    ```pycon
  372. {'b': 'foofoo'}
  373. ```
  374. Example: Using multiple nodes and multiple output channels
  375. ```python
  376. from langgraph.channels import LastValue, EphemeralValue
  377. from langgraph.pregel import Pregel, NodeBuilder
  378. node1 = (
  379. NodeBuilder().subscribe_only("a")
  380. .do(lambda x: x + x)
  381. .write_to("b")
  382. )
  383. node2 = (
  384. NodeBuilder().subscribe_to("b")
  385. .do(lambda x: x["b"] + x["b"])
  386. .write_to("c")
  387. )
  388. app = Pregel(
  389. nodes={"node1": node1, "node2": node2},
  390. channels={
  391. "a": EphemeralValue(str),
  392. "b": LastValue(str),
  393. "c": EphemeralValue(str),
  394. },
  395. input_channels=["a"],
  396. output_channels=["b", "c"],
  397. )
  398. app.invoke({"a": "foo"})
  399. ```
    ```pycon
  401. {'b': 'foofoo', 'c': 'foofoofoofoo'}
  402. ```
  403. Example: Using a Topic channel
  404. ```python
  405. from langgraph.channels import LastValue, EphemeralValue, Topic
  406. from langgraph.pregel import Pregel, NodeBuilder
  407. node1 = (
  408. NodeBuilder().subscribe_only("a")
  409. .do(lambda x: x + x)
  410. .write_to("b", "c")
  411. )
  412. node2 = (
  413. NodeBuilder().subscribe_only("b")
  414. .do(lambda x: x + x)
  415. .write_to("c")
  416. )
  417. app = Pregel(
  418. nodes={"node1": node1, "node2": node2},
  419. channels={
  420. "a": EphemeralValue(str),
  421. "b": EphemeralValue(str),
  422. "c": Topic(str, accumulate=True),
  423. },
  424. input_channels=["a"],
  425. output_channels=["c"],
  426. )
  427. app.invoke({"a": "foo"})
  428. ```
  429. ```pycon
  430. {"c": ["foofoo", "foofoofoofoo"]}
  431. ```
  432. Example: Using a `BinaryOperatorAggregate` channel
  433. ```python
  434. from langgraph.channels import EphemeralValue, BinaryOperatorAggregate
  435. from langgraph.pregel import Pregel, NodeBuilder
  436. node1 = (
  437. NodeBuilder().subscribe_only("a")
  438. .do(lambda x: x + x)
  439. .write_to("b", "c")
  440. )
  441. node2 = (
  442. NodeBuilder().subscribe_only("b")
  443. .do(lambda x: x + x)
  444. .write_to("c")
  445. )
  446. def reducer(current, update):
  447. if current:
  448. return current + " | " + update
  449. else:
  450. return update
  451. app = Pregel(
  452. nodes={"node1": node1, "node2": node2},
  453. channels={
  454. "a": EphemeralValue(str),
  455. "b": EphemeralValue(str),
  456. "c": BinaryOperatorAggregate(str, operator=reducer),
  457. },
  458. input_channels=["a"],
  459. output_channels=["c"],
  460. )
  461. app.invoke({"a": "foo"})
  462. ```
    ```pycon
  464. {'c': 'foofoo | foofoofoofoo'}
  465. ```
  466. Example: Introducing a cycle
  467. This example demonstrates how to introduce a cycle in the graph, by having
  468. a chain write to a channel it subscribes to.
  469. Execution will continue until a `None` value is written to the channel.
  470. ```python
  471. from langgraph.channels import EphemeralValue
  472. from langgraph.pregel import Pregel, NodeBuilder, ChannelWriteEntry
  473. example_node = (
  474. NodeBuilder()
  475. .subscribe_only("value")
  476. .do(lambda x: x + x if len(x) < 10 else None)
  477. .write_to(ChannelWriteEntry(channel="value", skip_none=True))
  478. )
  479. app = Pregel(
  480. nodes={"example_node": example_node},
  481. channels={
  482. "value": EphemeralValue(str),
  483. },
  484. input_channels=["value"],
  485. output_channels=["value"],
  486. )
  487. app.invoke({"value": "a"})
  488. ```
    ```pycon
  490. {'value': 'aaaaaaaaaaaaaaaa'}
  491. ```
  492. """
    # Compiled actors keyed by node name.
    nodes: dict[str, PregelNode]
    # Channel instances (or managed value specs) keyed by channel name.
    channels: dict[str, BaseChannel | ManagedValueSpec]
    stream_mode: StreamMode = "values"
    """Mode to stream output, defaults to 'values'."""
    stream_eager: bool = False
    """Whether to force emitting stream events eagerly, automatically turned on
    for stream_mode "messages" and "custom"."""
    # Channel(s) read to produce the graph's output.
    output_channels: str | Sequence[str]
    stream_channels: str | Sequence[str] | None = None
    """Channels to stream, defaults to all channels not in reserved channels"""
    # Node names (or the All sentinel) at which execution is interrupted
    # after / before the node runs.
    interrupt_after_nodes: All | Sequence[str]
    interrupt_before_nodes: All | Sequence[str]
    # Channel(s) the graph input is written to.
    input_channels: str | Sequence[str]
    step_timeout: float | None = None
    """Maximum time to wait for a step to complete, in seconds."""
    debug: bool
    """Whether to print debug information during execution."""
    checkpointer: Checkpointer = None
    """`Checkpointer` used to save and load graph state."""
    store: BaseStore | None = None
    """Memory store to use for SharedValues."""
    cache: BaseCache | None = None
    """Cache to use for storing node results."""
    retry_policy: Sequence[RetryPolicy] = ()
    """Retry policies to use when running tasks. Empty set disables retries."""
    cache_policy: CachePolicy | None = None
    """Cache policy to use for all nodes. Can be overridden by individual nodes."""
    context_schema: type[ContextT] | None = None
    """Specifies the schema for the context object that will be passed to the workflow."""
    # Base config merged into every run's config.
    config: RunnableConfig | None = None
    name: str = "LangGraph"
    # Maps a channel name to the nodes scheduled when that channel updates.
    trigger_to_nodes: Mapping[str, Sequence[str]]
    def __init__(
        self,
        *,
        nodes: dict[str, PregelNode | NodeBuilder],
        channels: dict[str, BaseChannel | ManagedValueSpec] | None,
        auto_validate: bool = True,
        stream_mode: StreamMode = "values",
        stream_eager: bool = False,
        output_channels: str | Sequence[str],
        stream_channels: str | Sequence[str] | None = None,
        interrupt_after_nodes: All | Sequence[str] = (),
        interrupt_before_nodes: All | Sequence[str] = (),
        input_channels: str | Sequence[str],
        step_timeout: float | None = None,
        debug: bool | None = None,
        checkpointer: BaseCheckpointSaver | None = None,
        store: BaseStore | None = None,
        cache: BaseCache | None = None,
        retry_policy: RetryPolicy | Sequence[RetryPolicy] = (),
        cache_policy: CachePolicy | None = None,
        context_schema: type[ContextT] | None = None,
        config: RunnableConfig | None = None,
        trigger_to_nodes: Mapping[str, Sequence[str]] | None = None,
        name: str = "LangGraph",
        **deprecated_kwargs: Unpack[DeprecatedKwargs],
    ) -> None:
        """Initialize a Pregel application from nodes and channels.

        Args:
            nodes: Mapping of node name to `PregelNode` (a `NodeBuilder` is
                built into a `PregelNode` here).
            channels: Mapping of channel name to channel; `None` means no
                user-defined channels.
            auto_validate: When `True`, `self.validate()` runs at the end.
            debug: Debug flag; falls back to the global langchain debug
                setting when `None`.
            retry_policy: A single policy or a sequence; normalized to a
                sequence.
            deprecated_kwargs: Accepts the deprecated `config_type` kwarg as
                an alias for `context_schema`.

        Raises:
            ValueError: If the reserved TASKS channel name is bound to a
                non-Topic channel.
        """
        # `config_type` was renamed to `context_schema`; accept the old kwarg
        # with a deprecation warning, and use it only when `context_schema`
        # was not also given explicitly.
        if (
            config_type := deprecated_kwargs.get("config_type", MISSING)
        ) is not MISSING:
            warnings.warn(
                "`config_type` is deprecated and will be removed. Please use `context_schema` instead.",
                category=LangGraphDeprecatedSinceV10,
                stacklevel=2,
            )
            if context_schema is None:
                context_schema = cast(type[ContextT], config_type)
        # Materialize any NodeBuilder values into concrete PregelNodes.
        self.nodes = {
            k: v.build() if isinstance(v, NodeBuilder) else v for k, v in nodes.items()
        }
        self.channels = channels or {}
        # TASKS is reserved: reject a caller-supplied non-Topic channel under
        # that name; otherwise install a fresh non-accumulating Topic(Send).
        # NOTE(review): the else branch also replaces a caller-supplied Topic
        # under TASKS with a fresh one — confirm that is intended.
        if TASKS in self.channels and not isinstance(self.channels[TASKS], Topic):
            raise ValueError(
                f"Channel '{TASKS}' is reserved and cannot be used in the graph."
            )
        else:
            self.channels[TASKS] = Topic(Send, accumulate=False)
        self.stream_mode = stream_mode
        self.stream_eager = stream_eager
        self.output_channels = output_channels
        self.stream_channels = stream_channels
        self.interrupt_after_nodes = interrupt_after_nodes
        self.interrupt_before_nodes = interrupt_before_nodes
        self.input_channels = input_channels
        self.step_timeout = step_timeout
        # Fall back to the process-wide langchain debug flag when unset.
        self.debug = debug if debug is not None else get_debug()
        self.checkpointer = checkpointer
        self.store = store
        self.cache = cache
        # Normalize a bare RetryPolicy into a one-element tuple.
        self.retry_policy = (
            (retry_policy,) if isinstance(retry_policy, RetryPolicy) else retry_policy
        )
        self.cache_policy = cache_policy
        self.context_schema = context_schema
        self.config = config
        self.trigger_to_nodes = trigger_to_nodes or {}
        self.name = name
        if auto_validate:
            self.validate()
  593. def get_graph(
  594. self, config: RunnableConfig | None = None, *, xray: int | bool = False
  595. ) -> Graph:
  596. """Return a drawable representation of the computation graph."""
  597. # gather subgraphs
  598. if xray:
  599. subgraphs = {
  600. k: v.get_graph(
  601. config,
  602. xray=xray if isinstance(xray, bool) or xray <= 0 else xray - 1,
  603. )
  604. for k, v in self.get_subgraphs()
  605. }
  606. else:
  607. subgraphs = {}
  608. return draw_graph(
  609. merge_configs(self.config, config),
  610. nodes=self.nodes,
  611. specs=self.channels,
  612. input_channels=self.input_channels,
  613. interrupt_after_nodes=self.interrupt_after_nodes,
  614. interrupt_before_nodes=self.interrupt_before_nodes,
  615. trigger_to_nodes=self.trigger_to_nodes,
  616. checkpointer=self.checkpointer,
  617. subgraphs=subgraphs,
  618. )
  619. async def aget_graph(
  620. self, config: RunnableConfig | None = None, *, xray: int | bool = False
  621. ) -> Graph:
  622. """Return a drawable representation of the computation graph."""
  623. # gather subgraphs
  624. if xray:
  625. subpregels: dict[str, PregelProtocol] = {
  626. k: v async for k, v in self.aget_subgraphs()
  627. }
  628. subgraphs = {
  629. k: v
  630. for k, v in zip(
  631. subpregels,
  632. await asyncio.gather(
  633. *(
  634. p.aget_graph(
  635. config,
  636. xray=xray
  637. if isinstance(xray, bool) or xray <= 0
  638. else xray - 1,
  639. )
  640. for p in subpregels.values()
  641. )
  642. ),
  643. )
  644. }
  645. else:
  646. subgraphs = {}
  647. return draw_graph(
  648. merge_configs(self.config, config),
  649. nodes=self.nodes,
  650. specs=self.channels,
  651. input_channels=self.input_channels,
  652. interrupt_after_nodes=self.interrupt_after_nodes,
  653. interrupt_before_nodes=self.interrupt_before_nodes,
  654. trigger_to_nodes=self.trigger_to_nodes,
  655. checkpointer=self.checkpointer,
  656. subgraphs=subgraphs,
  657. )
  658. def _repr_mimebundle_(self, **kwargs: Any) -> dict[str, Any]:
  659. """Mime bundle used by Jupyter to display the graph"""
  660. return {
  661. "text/plain": repr(self),
  662. "image/png": self.get_graph().draw_mermaid_png(),
  663. }
  664. def copy(self, update: dict[str, Any] | None = None) -> Self:
  665. attrs = {k: v for k, v in self.__dict__.items() if k != "__orig_class__"}
  666. attrs.update(update or {})
  667. return self.__class__(**attrs)
  668. def with_config(self, config: RunnableConfig | None = None, **kwargs: Any) -> Self:
  669. """Create a copy of the Pregel object with an updated config."""
  670. return self.copy(
  671. {"config": merge_configs(self.config, config, cast(RunnableConfig, kwargs))}
  672. )
  673. def validate(self) -> Self:
  674. validate_graph(
  675. self.nodes,
  676. {k: v for k, v in self.channels.items() if isinstance(v, BaseChannel)},
  677. {k: v for k, v in self.channels.items() if not isinstance(v, BaseChannel)},
  678. self.input_channels,
  679. self.output_channels,
  680. self.stream_channels,
  681. self.interrupt_after_nodes,
  682. self.interrupt_before_nodes,
  683. )
  684. self.trigger_to_nodes = _trigger_to_nodes(self.nodes)
  685. return self
  686. @deprecated(
  687. "`config_schema` is deprecated. Use `get_context_jsonschema` for the relevant schema instead.",
  688. category=None,
  689. )
  690. def config_schema(self, *, include: Sequence[str] | None = None) -> type[BaseModel]:
  691. warnings.warn(
  692. "`config_schema` is deprecated. Use `get_context_jsonschema` for the relevant schema instead.",
  693. category=LangGraphDeprecatedSinceV10,
  694. stacklevel=2,
  695. )
  696. include = include or []
  697. fields = {
  698. **(
  699. {"configurable": (self.context_schema, None)}
  700. if self.context_schema
  701. else {}
  702. ),
  703. **{
  704. field_name: (field_type, None)
  705. for field_name, field_type in get_type_hints(RunnableConfig).items()
  706. if field_name in [i for i in include if i != "configurable"]
  707. },
  708. }
  709. return create_model(self.get_name("Config"), field_definitions=fields)
  710. @deprecated(
  711. "`get_config_jsonschema` is deprecated. Use `get_context_jsonschema` instead.",
  712. category=None,
  713. )
  714. def get_config_jsonschema(
  715. self, *, include: Sequence[str] | None = None
  716. ) -> dict[str, Any]:
  717. warnings.warn(
  718. "`get_config_jsonschema` is deprecated. Use `get_context_jsonschema` instead.",
  719. category=LangGraphDeprecatedSinceV10,
  720. stacklevel=2,
  721. )
  722. with warnings.catch_warnings():
  723. warnings.filterwarnings("ignore", category=LangGraphDeprecatedSinceV10)
  724. schema = self.config_schema(include=include)
  725. return schema.model_json_schema()
  726. def get_context_jsonschema(self) -> dict[str, Any] | None:
  727. if (context_schema := self.context_schema) is None:
  728. return None
  729. if isclass(context_schema) and issubclass(context_schema, BaseModel):
  730. return context_schema.model_json_schema()
  731. elif is_typeddict(context_schema) or is_dataclass(context_schema):
  732. return TypeAdapter(context_schema).json_schema()
  733. else:
  734. raise ValueError(
  735. f"Invalid context schema type: {context_schema}. Must be a BaseModel, TypedDict or dataclass."
  736. )
  737. @property
  738. def InputType(self) -> Any:
  739. if isinstance(self.input_channels, str):
  740. channel = self.channels[self.input_channels]
  741. if isinstance(channel, BaseChannel):
  742. return channel.UpdateType
  743. def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
  744. config = merge_configs(self.config, config)
  745. if isinstance(self.input_channels, str):
  746. return super().get_input_schema(config)
  747. else:
  748. return create_model(
  749. self.get_name("Input"),
  750. field_definitions={
  751. k: (c.UpdateType, None)
  752. for k in self.input_channels or self.channels.keys()
  753. if (c := self.channels[k]) and isinstance(c, BaseChannel)
  754. },
  755. )
  756. def get_input_jsonschema(
  757. self, config: RunnableConfig | None = None
  758. ) -> dict[str, Any]:
  759. schema = self.get_input_schema(config)
  760. return schema.model_json_schema()
  761. @property
  762. def OutputType(self) -> Any:
  763. if isinstance(self.output_channels, str):
  764. channel = self.channels[self.output_channels]
  765. if isinstance(channel, BaseChannel):
  766. return channel.ValueType
  767. def get_output_schema(
  768. self, config: RunnableConfig | None = None
  769. ) -> type[BaseModel]:
  770. config = merge_configs(self.config, config)
  771. if isinstance(self.output_channels, str):
  772. return super().get_output_schema(config)
  773. else:
  774. return create_model(
  775. self.get_name("Output"),
  776. field_definitions={
  777. k: (c.ValueType, None)
  778. for k in self.output_channels
  779. if (c := self.channels[k]) and isinstance(c, BaseChannel)
  780. },
  781. )
  782. def get_output_jsonschema(
  783. self, config: RunnableConfig | None = None
  784. ) -> dict[str, Any]:
  785. schema = self.get_output_schema(config)
  786. return schema.model_json_schema()
  787. @property
  788. def stream_channels_list(self) -> Sequence[str]:
  789. stream_channels = self.stream_channels_asis
  790. return (
  791. [stream_channels] if isinstance(stream_channels, str) else stream_channels
  792. )
  793. @property
  794. def stream_channels_asis(self) -> str | Sequence[str]:
  795. return self.stream_channels or [
  796. k for k in self.channels if isinstance(self.channels[k], BaseChannel)
  797. ]
def get_subgraphs(
    self, *, namespace: str | None = None, recurse: bool = False
) -> Iterator[tuple[str, PregelProtocol]]:
    """Get the subgraphs of the graph.

    Args:
        namespace: The namespace to filter the subgraphs by.
        recurse: Whether to recurse into the subgraphs.
            If `False`, only the immediate subgraphs will be returned.

    Returns:
        An iterator of the `(namespace, subgraph)` pairs.
    """
    for name, node in self.nodes.items():
        # filter by prefix
        if namespace is not None:
            if not namespace.startswith(name):
                continue
        # find the subgraph, if any
        # NOTE: only the node's first subgraph is considered here
        graph = node.subgraphs[0] if node.subgraphs else None
        # if found, yield recursively
        if graph:
            if name == namespace:
                yield name, graph
                return  # we found it, stop searching
            if namespace is None:
                yield name, graph
            if recurse and isinstance(graph, Pregel):
                if namespace is not None:
                    # strip this node's name (plus separator) so the child
                    # sees the namespace relative to itself
                    namespace = namespace[len(name) + 1 :]
                yield from (
                    (f"{name}{NS_SEP}{n}", s)
                    for n, s in graph.get_subgraphs(
                        namespace=namespace, recurse=recurse
                    )
                )
  832. async def aget_subgraphs(
  833. self, *, namespace: str | None = None, recurse: bool = False
  834. ) -> AsyncIterator[tuple[str, PregelProtocol]]:
  835. """Get the subgraphs of the graph.
  836. Args:
  837. namespace: The namespace to filter the subgraphs by.
  838. recurse: Whether to recurse into the subgraphs.
  839. If `False`, only the immediate subgraphs will be returned.
  840. Returns:
  841. An iterator of the `(namespace, subgraph)` pairs.
  842. """
  843. for name, node in self.get_subgraphs(namespace=namespace, recurse=recurse):
  844. yield name, node
  845. def _migrate_checkpoint(self, checkpoint: Checkpoint) -> None:
  846. """Migrate a saved checkpoint to new channel layout."""
  847. if checkpoint["v"] < 4 and checkpoint.get("pending_sends"):
  848. pending_sends: list[Send] = checkpoint.pop("pending_sends")
  849. checkpoint["channel_values"][TASKS] = pending_sends
  850. checkpoint["channel_versions"][TASKS] = max(
  851. checkpoint["channel_versions"].values()
  852. )
def _prepare_state_snapshot(
    self,
    config: RunnableConfig,
    saved: CheckpointTuple | None,
    recurse: BaseCheckpointSaver | None = None,
    apply_pending_writes: bool = False,
) -> StateSnapshot:
    """Assemble a `StateSnapshot` from a saved checkpoint tuple.

    Args:
        config: Config the snapshot was requested with; echoed back when
            there is no saved checkpoint.
        saved: The checkpoint tuple to build the snapshot from, if any.
        recurse: When set, a checkpointer used to recursively fetch the
            state of subgraph tasks; otherwise subgraph tasks only carry a
            config signaling that subgraph checkpoints exist.
        apply_pending_writes: Whether to apply successful pending writes to
            the channels before reading values.
    """
    if not saved:
        # no checkpoint: return an empty snapshot echoing the config
        return StateSnapshot(
            values={},
            next=(),
            config=config,
            metadata=None,
            created_at=None,
            parent_config=None,
            tasks=(),
            interrupts=(),
        )
    # migrate checkpoint if needed
    self._migrate_checkpoint(saved.checkpoint)
    step = saved.metadata.get("step", -1) + 1
    stop = step + 2
    channels, managed = channels_from_checkpoint(
        self.channels,
        saved.checkpoint,
    )
    # tasks for this checkpoint
    next_tasks = prepare_next_tasks(
        saved.checkpoint,
        saved.pending_writes or [],
        self.nodes,
        channels,
        managed,
        saved.config,
        step,
        stop,
        for_execution=True,
        store=self.store,
        checkpointer=(
            self.checkpointer
            if isinstance(self.checkpointer, BaseCheckpointSaver)
            else None
        ),
        manager=None,
    )
    # get the subgraphs
    subgraphs = dict(self.get_subgraphs())
    parent_ns = saved.config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")
    task_states: dict[str, RunnableConfig | StateSnapshot] = {}
    for task in next_tasks.values():
        if task.name not in subgraphs:
            continue
        # assemble checkpoint_ns for this task
        task_ns = f"{task.name}{NS_END}{task.id}"
        if parent_ns:
            task_ns = f"{parent_ns}{NS_SEP}{task_ns}"
        if not recurse:
            # set config as signal that subgraph checkpoints exist
            config = {
                CONF: {
                    "thread_id": saved.config[CONF]["thread_id"],
                    CONFIG_KEY_CHECKPOINT_NS: task_ns,
                }
            }
            task_states[task.id] = config
        else:
            # get the state of the subgraph
            config = {
                CONF: {
                    CONFIG_KEY_CHECKPOINTER: recurse,
                    "thread_id": saved.config[CONF]["thread_id"],
                    CONFIG_KEY_CHECKPOINT_NS: task_ns,
                }
            }
            task_states[task.id] = subgraphs[task.name].get_state(
                config, subgraphs=True
            )
    # apply pending writes
    # null-task writes (e.g. resume values) are applied first
    if null_writes := [
        w[1:] for w in saved.pending_writes or [] if w[0] == NULL_TASK_ID
    ]:
        apply_writes(
            saved.checkpoint,
            channels,
            [PregelTaskWrites((), INPUT, null_writes, [])],
            None,
            self.trigger_to_nodes,
        )
    if apply_pending_writes and saved.pending_writes:
        # attach each successful pending write to its task, skipping
        # error/interrupt markers and writes for unknown tasks
        for tid, k, v in saved.pending_writes:
            if k in (ERROR, INTERRUPT):
                continue
            if tid not in next_tasks:
                continue
            next_tasks[tid].writes.append((k, v))
        if tasks := [t for t in next_tasks.values() if t.writes]:
            apply_writes(
                saved.checkpoint, channels, tasks, None, self.trigger_to_nodes
            )
    tasks_with_writes = tasks_w_writes(
        next_tasks.values(),
        saved.pending_writes,
        task_states,
        self.stream_channels_asis,
    )
    # assemble the state snapshot
    # `next` only lists tasks that have not completed (no writes yet)
    return StateSnapshot(
        read_channels(channels, self.stream_channels_asis),
        tuple(t.name for t in next_tasks.values() if not t.writes),
        patch_checkpoint_map(saved.config, saved.metadata),
        saved.metadata,
        saved.checkpoint["ts"],
        patch_checkpoint_map(saved.parent_config, saved.metadata),
        tasks_with_writes,
        tuple([i for task in tasks_with_writes for i in task.interrupts]),
    )
async def _aprepare_state_snapshot(
    self,
    config: RunnableConfig,
    saved: CheckpointTuple | None,
    recurse: BaseCheckpointSaver | None = None,
    apply_pending_writes: bool = False,
) -> StateSnapshot:
    """Async counterpart of `_prepare_state_snapshot`.

    Args:
        config: Config the snapshot was requested with; echoed back when
            there is no saved checkpoint.
        saved: The checkpoint tuple to build the snapshot from, if any.
        recurse: When set, a checkpointer used to recursively fetch the
            state of subgraph tasks; otherwise subgraph tasks only carry a
            config signaling that subgraph checkpoints exist.
        apply_pending_writes: Whether to apply successful pending writes to
            the channels before reading values.
    """
    if not saved:
        # no checkpoint: return an empty snapshot echoing the config
        return StateSnapshot(
            values={},
            next=(),
            config=config,
            metadata=None,
            created_at=None,
            parent_config=None,
            tasks=(),
            interrupts=(),
        )
    # migrate checkpoint if needed
    self._migrate_checkpoint(saved.checkpoint)
    step = saved.metadata.get("step", -1) + 1
    stop = step + 2
    channels, managed = channels_from_checkpoint(
        self.channels,
        saved.checkpoint,
    )
    # tasks for this checkpoint
    next_tasks = prepare_next_tasks(
        saved.checkpoint,
        saved.pending_writes or [],
        self.nodes,
        channels,
        managed,
        saved.config,
        step,
        stop,
        for_execution=True,
        store=self.store,
        checkpointer=(
            self.checkpointer
            if isinstance(self.checkpointer, BaseCheckpointSaver)
            else None
        ),
        manager=None,
    )
    # get the subgraphs
    subgraphs = {n: g async for n, g in self.aget_subgraphs()}
    parent_ns = saved.config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")
    task_states: dict[str, RunnableConfig | StateSnapshot] = {}
    for task in next_tasks.values():
        if task.name not in subgraphs:
            continue
        # assemble checkpoint_ns for this task
        task_ns = f"{task.name}{NS_END}{task.id}"
        if parent_ns:
            task_ns = f"{parent_ns}{NS_SEP}{task_ns}"
        if not recurse:
            # set config as signal that subgraph checkpoints exist
            config = {
                CONF: {
                    "thread_id": saved.config[CONF]["thread_id"],
                    CONFIG_KEY_CHECKPOINT_NS: task_ns,
                }
            }
            task_states[task.id] = config
        else:
            # get the state of the subgraph
            config = {
                CONF: {
                    CONFIG_KEY_CHECKPOINTER: recurse,
                    "thread_id": saved.config[CONF]["thread_id"],
                    CONFIG_KEY_CHECKPOINT_NS: task_ns,
                }
            }
            task_states[task.id] = await subgraphs[task.name].aget_state(
                config, subgraphs=True
            )
    # apply pending writes
    # null-task writes (e.g. resume values) are applied first
    if null_writes := [
        w[1:] for w in saved.pending_writes or [] if w[0] == NULL_TASK_ID
    ]:
        apply_writes(
            saved.checkpoint,
            channels,
            [PregelTaskWrites((), INPUT, null_writes, [])],
            None,
            self.trigger_to_nodes,
        )
    if apply_pending_writes and saved.pending_writes:
        # attach each successful pending write to its task, skipping
        # error/interrupt markers and writes for unknown tasks
        for tid, k, v in saved.pending_writes:
            if k in (ERROR, INTERRUPT):
                continue
            if tid not in next_tasks:
                continue
            next_tasks[tid].writes.append((k, v))
        if tasks := [t for t in next_tasks.values() if t.writes]:
            apply_writes(
                saved.checkpoint, channels, tasks, None, self.trigger_to_nodes
            )
    tasks_with_writes = tasks_w_writes(
        next_tasks.values(),
        saved.pending_writes,
        task_states,
        self.stream_channels_asis,
    )
    # assemble the state snapshot
    # `next` only lists tasks that have not completed (no writes yet)
    return StateSnapshot(
        read_channels(channels, self.stream_channels_asis),
        tuple(t.name for t in next_tasks.values() if not t.writes),
        patch_checkpoint_map(saved.config, saved.metadata),
        saved.metadata,
        saved.checkpoint["ts"],
        patch_checkpoint_map(saved.parent_config, saved.metadata),
        tasks_with_writes,
        tuple([i for task in tasks_with_writes for i in task.interrupts]),
    )
def get_state(
    self, config: RunnableConfig, *, subgraphs: bool = False
) -> StateSnapshot:
    """Get the current state of the graph.

    Raises:
        ValueError: If no checkpointer is set, or the checkpoint namespace
            refers to a subgraph that cannot be found.
    """
    checkpointer: BaseCheckpointSaver | None = ensure_config(config)[CONF].get(
        CONFIG_KEY_CHECKPOINTER, self.checkpointer
    )
    if not checkpointer:
        raise ValueError("No checkpointer set")
    # a non-empty checkpoint_ns without an inherited checkpointer means the
    # request targets a subgraph: delegate to it
    if (
        checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")
    ) and CONFIG_KEY_CHECKPOINTER not in config[CONF]:
        # remove task_ids from checkpoint_ns
        recast = recast_checkpoint_ns(checkpoint_ns)
        # find the subgraph with the matching name
        for _, pregel in self.get_subgraphs(namespace=recast, recurse=True):
            return pregel.get_state(
                patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}),
                subgraphs=subgraphs,
            )
        else:
            # for/else: no matching subgraph was yielded
            raise ValueError(f"Subgraph {recast} not found")
    config = merge_configs(self.config, config) if self.config else config
    # checkpointer=True means "inherit from parent": strip task ids from ns
    if self.checkpointer is True:
        ns = cast(str, config[CONF][CONFIG_KEY_CHECKPOINT_NS])
        config = merge_configs(
            config, {CONF: {CONFIG_KEY_CHECKPOINT_NS: recast_checkpoint_ns(ns)}}
        )
    # normalize thread_id to a string before querying the checkpointer
    thread_id = config[CONF][CONFIG_KEY_THREAD_ID]
    if not isinstance(thread_id, str):
        config[CONF][CONFIG_KEY_THREAD_ID] = str(thread_id)
    saved = checkpointer.get_tuple(config)
    return self._prepare_state_snapshot(
        config,
        saved,
        recurse=checkpointer if subgraphs else None,
        # only apply pending writes when reading the latest checkpoint
        apply_pending_writes=CONFIG_KEY_CHECKPOINT_ID not in config[CONF],
    )
async def aget_state(
    self, config: RunnableConfig, *, subgraphs: bool = False
) -> StateSnapshot:
    """Get the current state of the graph.

    Raises:
        ValueError: If no checkpointer is set, or the checkpoint namespace
            refers to a subgraph that cannot be found.
    """
    checkpointer: BaseCheckpointSaver | None = ensure_config(config)[CONF].get(
        CONFIG_KEY_CHECKPOINTER, self.checkpointer
    )
    if not checkpointer:
        raise ValueError("No checkpointer set")
    # a non-empty checkpoint_ns without an inherited checkpointer means the
    # request targets a subgraph: delegate to it
    if (
        checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")
    ) and CONFIG_KEY_CHECKPOINTER not in config[CONF]:
        # remove task_ids from checkpoint_ns
        recast = recast_checkpoint_ns(checkpoint_ns)
        # find the subgraph with the matching name
        async for _, pregel in self.aget_subgraphs(namespace=recast, recurse=True):
            return await pregel.aget_state(
                patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}),
                subgraphs=subgraphs,
            )
        else:
            # for/else: no matching subgraph was yielded
            raise ValueError(f"Subgraph {recast} not found")
    config = merge_configs(self.config, config) if self.config else config
    # checkpointer=True means "inherit from parent": strip task ids from ns
    if self.checkpointer is True:
        ns = cast(str, config[CONF][CONFIG_KEY_CHECKPOINT_NS])
        config = merge_configs(
            config, {CONF: {CONFIG_KEY_CHECKPOINT_NS: recast_checkpoint_ns(ns)}}
        )
    # normalize thread_id to a string before querying the checkpointer
    thread_id = config[CONF][CONFIG_KEY_THREAD_ID]
    if not isinstance(thread_id, str):
        config[CONF][CONFIG_KEY_THREAD_ID] = str(thread_id)
    saved = await checkpointer.aget_tuple(config)
    return await self._aprepare_state_snapshot(
        config,
        saved,
        recurse=checkpointer if subgraphs else None,
        # only apply pending writes when reading the latest checkpoint
        apply_pending_writes=CONFIG_KEY_CHECKPOINT_ID not in config[CONF],
    )
def get_state_history(
    self,
    config: RunnableConfig,
    *,
    filter: dict[str, Any] | None = None,
    before: RunnableConfig | None = None,
    limit: int | None = None,
) -> Iterator[StateSnapshot]:
    """Get the history of the state of the graph.

    Args:
        config: Config identifying the thread (and optionally namespace).
        filter: Metadata filter passed through to the checkpointer.
        before: Only return checkpoints created before this config's one.
        limit: Maximum number of snapshots to yield.

    Raises:
        ValueError: If no checkpointer is set, or the checkpoint namespace
            refers to a subgraph that cannot be found.
    """
    config = ensure_config(config)
    checkpointer: BaseCheckpointSaver | None = config[CONF].get(
        CONFIG_KEY_CHECKPOINTER, self.checkpointer
    )
    if not checkpointer:
        raise ValueError("No checkpointer set")
    # a non-empty checkpoint_ns without an inherited checkpointer means the
    # request targets a subgraph: delegate to it
    if (
        checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")
    ) and CONFIG_KEY_CHECKPOINTER not in config[CONF]:
        # remove task_ids from checkpoint_ns
        recast = recast_checkpoint_ns(checkpoint_ns)
        # find the subgraph with the matching name
        for _, pregel in self.get_subgraphs(namespace=recast, recurse=True):
            yield from pregel.get_state_history(
                patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}),
                filter=filter,
                before=before,
                limit=limit,
            )
            return
        else:
            # for/else: no matching subgraph was yielded
            raise ValueError(f"Subgraph {recast} not found")
    config = merge_configs(
        self.config,
        config,
        {
            CONF: {
                CONFIG_KEY_CHECKPOINT_NS: checkpoint_ns,
                CONFIG_KEY_THREAD_ID: str(config[CONF][CONFIG_KEY_THREAD_ID]),
            }
        },
    )
    # eagerly consume list() to avoid holding up the db cursor
    for checkpoint_tuple in list(
        checkpointer.list(config, before=before, limit=limit, filter=filter)
    ):
        yield self._prepare_state_snapshot(
            checkpoint_tuple.config, checkpoint_tuple
        )
  1209. async def aget_state_history(
  1210. self,
  1211. config: RunnableConfig,
  1212. *,
  1213. filter: dict[str, Any] | None = None,
  1214. before: RunnableConfig | None = None,
  1215. limit: int | None = None,
  1216. ) -> AsyncIterator[StateSnapshot]:
  1217. """Asynchronously get the history of the state of the graph."""
  1218. config = ensure_config(config)
  1219. checkpointer: BaseCheckpointSaver | None = ensure_config(config)[CONF].get(
  1220. CONFIG_KEY_CHECKPOINTER, self.checkpointer
  1221. )
  1222. if not checkpointer:
  1223. raise ValueError("No checkpointer set")
  1224. if (
  1225. checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")
  1226. ) and CONFIG_KEY_CHECKPOINTER not in config[CONF]:
  1227. # remove task_ids from checkpoint_ns
  1228. recast = recast_checkpoint_ns(checkpoint_ns)
  1229. # find the subgraph with the matching name
  1230. async for _, pregel in self.aget_subgraphs(namespace=recast, recurse=True):
  1231. async for state in pregel.aget_state_history(
  1232. patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}),
  1233. filter=filter,
  1234. before=before,
  1235. limit=limit,
  1236. ):
  1237. yield state
  1238. return
  1239. else:
  1240. raise ValueError(f"Subgraph {recast} not found")
  1241. config = merge_configs(
  1242. self.config,
  1243. config,
  1244. {
  1245. CONF: {
  1246. CONFIG_KEY_CHECKPOINT_NS: checkpoint_ns,
  1247. CONFIG_KEY_THREAD_ID: str(config[CONF][CONFIG_KEY_THREAD_ID]),
  1248. }
  1249. },
  1250. )
  1251. # eagerly consume list() to avoid holding up the db cursor
  1252. for checkpoint_tuple in [
  1253. c
  1254. async for c in checkpointer.alist(
  1255. config, before=before, limit=limit, filter=filter
  1256. )
  1257. ]:
  1258. yield await self._aprepare_state_snapshot(
  1259. checkpoint_tuple.config, checkpoint_tuple
  1260. )
  1261. def bulk_update_state(
  1262. self,
  1263. config: RunnableConfig,
  1264. supersteps: Sequence[Sequence[StateUpdate]],
  1265. ) -> RunnableConfig:
  1266. """Apply updates to the graph state in bulk. Requires a checkpointer to be set.
  1267. Args:
  1268. config: The config to apply the updates to.
  1269. supersteps: A list of supersteps, each including a list of updates to apply sequentially to a graph state.
  1270. Each update is a tuple of the form `(values, as_node, task_id)` where `task_id` is optional.
  1271. Raises:
  1272. ValueError: If no checkpointer is set or no updates are provided.
  1273. InvalidUpdateError: If an invalid update is provided.
  1274. Returns:
  1275. RunnableConfig: The updated config.
  1276. """
  1277. checkpointer: BaseCheckpointSaver | None = ensure_config(config)[CONF].get(
  1278. CONFIG_KEY_CHECKPOINTER, self.checkpointer
  1279. )
  1280. if not checkpointer:
  1281. raise ValueError("No checkpointer set")
  1282. if len(supersteps) == 0:
  1283. raise ValueError("No supersteps provided")
  1284. if any(len(u) == 0 for u in supersteps):
  1285. raise ValueError("No updates provided")
  1286. # delegate to subgraph
  1287. if (
  1288. checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")
  1289. ) and CONFIG_KEY_CHECKPOINTER not in config[CONF]:
  1290. # remove task_ids from checkpoint_ns
  1291. recast = recast_checkpoint_ns(checkpoint_ns)
  1292. # find the subgraph with the matching name
  1293. for _, pregel in self.get_subgraphs(namespace=recast, recurse=True):
  1294. return pregel.bulk_update_state(
  1295. patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}),
  1296. supersteps,
  1297. )
  1298. else:
  1299. raise ValueError(f"Subgraph {recast} not found")
  1300. def perform_superstep(
  1301. input_config: RunnableConfig, updates: Sequence[StateUpdate]
  1302. ) -> RunnableConfig:
  1303. # get last checkpoint
  1304. config = ensure_config(self.config, input_config)
  1305. saved = checkpointer.get_tuple(config)
  1306. if saved is not None:
  1307. self._migrate_checkpoint(saved.checkpoint)
  1308. checkpoint = (
  1309. copy_checkpoint(saved.checkpoint) if saved else empty_checkpoint()
  1310. )
  1311. checkpoint_previous_versions = (
  1312. saved.checkpoint["channel_versions"].copy() if saved else {}
  1313. )
  1314. step = saved.metadata.get("step", -1) if saved else -1
  1315. # merge configurable fields with previous checkpoint config
  1316. checkpoint_config = patch_configurable(
  1317. config,
  1318. {
  1319. CONFIG_KEY_CHECKPOINT_NS: config[CONF].get(
  1320. CONFIG_KEY_CHECKPOINT_NS, ""
  1321. )
  1322. },
  1323. )
  1324. if saved:
  1325. checkpoint_config = patch_configurable(config, saved.config[CONF])
  1326. channels, managed = channels_from_checkpoint(
  1327. self.channels,
  1328. checkpoint,
  1329. )
  1330. values, as_node = updates[0][:2]
  1331. # no values as END, just clear all tasks
  1332. if values is None and as_node == END:
  1333. if len(updates) > 1:
  1334. raise InvalidUpdateError(
  1335. "Cannot apply multiple updates when clearing state"
  1336. )
  1337. if saved is not None:
  1338. # tasks for this checkpoint
  1339. next_tasks = prepare_next_tasks(
  1340. checkpoint,
  1341. saved.pending_writes or [],
  1342. self.nodes,
  1343. channels,
  1344. managed,
  1345. saved.config,
  1346. step + 1,
  1347. step + 3,
  1348. for_execution=True,
  1349. store=self.store,
  1350. checkpointer=checkpointer,
  1351. manager=None,
  1352. )
  1353. # apply null writes
  1354. if null_writes := [
  1355. w[1:]
  1356. for w in saved.pending_writes or []
  1357. if w[0] == NULL_TASK_ID
  1358. ]:
  1359. apply_writes(
  1360. checkpoint,
  1361. channels,
  1362. [PregelTaskWrites((), INPUT, null_writes, [])],
  1363. checkpointer.get_next_version,
  1364. self.trigger_to_nodes,
  1365. )
  1366. # apply writes from tasks that already ran
  1367. for tid, k, v in saved.pending_writes or []:
  1368. if k in (ERROR, INTERRUPT):
  1369. continue
  1370. if tid not in next_tasks:
  1371. continue
  1372. next_tasks[tid].writes.append((k, v))
  1373. # clear all current tasks
  1374. apply_writes(
  1375. checkpoint,
  1376. channels,
  1377. next_tasks.values(),
  1378. checkpointer.get_next_version,
  1379. self.trigger_to_nodes,
  1380. )
  1381. # save checkpoint
  1382. next_config = checkpointer.put(
  1383. checkpoint_config,
  1384. create_checkpoint(checkpoint, channels, step),
  1385. {
  1386. "source": "update",
  1387. "step": step + 1,
  1388. "parents": saved.metadata.get("parents", {}) if saved else {},
  1389. },
  1390. get_new_channel_versions(
  1391. checkpoint_previous_versions,
  1392. checkpoint["channel_versions"],
  1393. ),
  1394. )
  1395. return patch_checkpoint_map(
  1396. next_config, saved.metadata if saved else None
  1397. )
  1398. # act as an input
  1399. if as_node == INPUT:
  1400. if len(updates) > 1:
  1401. raise InvalidUpdateError(
  1402. "Cannot apply multiple updates when updating as input"
  1403. )
  1404. if input_writes := deque(map_input(self.input_channels, values)):
  1405. apply_writes(
  1406. checkpoint,
  1407. channels,
  1408. [PregelTaskWrites((), INPUT, input_writes, [])],
  1409. checkpointer.get_next_version,
  1410. self.trigger_to_nodes,
  1411. )
  1412. # apply input write to channels
  1413. next_step = (
  1414. step + 1
  1415. if saved and saved.metadata.get("step") is not None
  1416. else -1
  1417. )
  1418. next_config = checkpointer.put(
  1419. checkpoint_config,
  1420. create_checkpoint(checkpoint, channels, next_step),
  1421. {
  1422. "source": "input",
  1423. "step": next_step,
  1424. "parents": saved.metadata.get("parents", {})
  1425. if saved
  1426. else {},
  1427. },
  1428. get_new_channel_versions(
  1429. checkpoint_previous_versions,
  1430. checkpoint["channel_versions"],
  1431. ),
  1432. )
  1433. # store the writes
  1434. checkpointer.put_writes(
  1435. next_config,
  1436. input_writes,
  1437. str(uuid5(UUID(checkpoint["id"]), INPUT)),
  1438. )
  1439. return patch_checkpoint_map(
  1440. next_config, saved.metadata if saved else None
  1441. )
  1442. else:
  1443. raise InvalidUpdateError(
  1444. f"Received no input writes for {self.input_channels}"
  1445. )
  1446. # copy checkpoint
  1447. if as_node == "__copy__":
  1448. if len(updates) > 1:
  1449. raise InvalidUpdateError(
  1450. "Cannot copy checkpoint with multiple updates"
  1451. )
  1452. if saved is None:
  1453. raise InvalidUpdateError("Cannot copy a non-existent checkpoint")
  1454. next_checkpoint = create_checkpoint(checkpoint, None, step)
  1455. # copy checkpoint
  1456. next_config = checkpointer.put(
  1457. saved.parent_config
  1458. or patch_configurable(
  1459. saved.config, {CONFIG_KEY_CHECKPOINT_ID: None}
  1460. ),
  1461. next_checkpoint,
  1462. {
  1463. "source": "fork",
  1464. "step": step + 1,
  1465. "parents": saved.metadata.get("parents", {}),
  1466. },
  1467. {},
  1468. )
  1469. # we want to both clone a checkpoint and update state in one go.
  1470. # reuse the same task ID if possible.
  1471. if isinstance(values, list) and len(values) > 0:
  1472. # figure out the task IDs for the next update checkpoint
  1473. next_tasks = prepare_next_tasks(
  1474. next_checkpoint,
  1475. saved.pending_writes or [],
  1476. self.nodes,
  1477. channels,
  1478. managed,
  1479. next_config,
  1480. step + 2,
  1481. step + 4,
  1482. for_execution=True,
  1483. store=self.store,
  1484. checkpointer=checkpointer,
  1485. manager=None,
  1486. )
  1487. tasks_group_by = defaultdict(list)
  1488. user_group_by: dict[str, list[StateUpdate]] = defaultdict(list)
  1489. for task in next_tasks.values():
  1490. tasks_group_by[task.name].append(task.id)
  1491. for item in values:
  1492. if not isinstance(item, Sequence):
  1493. raise InvalidUpdateError(
  1494. f"Invalid update item: {item} when copying checkpoint"
  1495. )
  1496. values, as_node = item[:2]
  1497. user_group = user_group_by[as_node]
  1498. tasks_group = tasks_group_by[as_node]
  1499. target_idx = len(user_group)
  1500. task_id = (
  1501. tasks_group[target_idx]
  1502. if target_idx < len(tasks_group)
  1503. else None
  1504. )
  1505. user_group_by[as_node].append(
  1506. StateUpdate(values=values, as_node=as_node, task_id=task_id)
  1507. )
  1508. return perform_superstep(
  1509. patch_checkpoint_map(next_config, saved.metadata),
  1510. [item for lst in user_group_by.values() for item in lst],
  1511. )
  1512. return patch_checkpoint_map(next_config, saved.metadata)
  1513. # task ids can be provided in the StateUpdate, but if not,
  1514. # we use the task id generated by prepare_next_tasks
  1515. node_to_task_ids: dict[str, deque[str]] = defaultdict(deque)
  1516. if saved is not None and saved.pending_writes is not None:
  1517. # we call prepare_next_tasks to discover the task IDs that
  1518. # would have been generated, so we can reuse them and
  1519. # properly populate task.result in state history
  1520. next_tasks = prepare_next_tasks(
  1521. checkpoint,
  1522. saved.pending_writes,
  1523. self.nodes,
  1524. channels,
  1525. managed,
  1526. saved.config,
  1527. step + 1,
  1528. step + 3,
  1529. for_execution=True,
  1530. store=self.store,
  1531. checkpointer=checkpointer,
  1532. manager=None,
  1533. )
  1534. # collect task ids to reuse so we can properly attach task results
  1535. for t in next_tasks.values():
  1536. node_to_task_ids[t.name].append(t.id)
  1537. valid_updates: list[tuple[str, dict[str, Any] | None, str | None]] = []
  1538. if len(updates) == 1:
  1539. values, as_node, task_id = updates[0]
  1540. # find last node that updated the state, if not provided
  1541. if as_node is None and len(self.nodes) == 1:
  1542. as_node = tuple(self.nodes)[0]
  1543. elif as_node is None and not any(
  1544. v
  1545. for vv in checkpoint["versions_seen"].values()
  1546. for v in vv.values()
  1547. ):
  1548. if (
  1549. isinstance(self.input_channels, str)
  1550. and self.input_channels in self.nodes
  1551. ):
  1552. as_node = self.input_channels
  1553. elif as_node is None:
  1554. last_seen_by_node = sorted(
  1555. (v, n)
  1556. for n, seen in checkpoint["versions_seen"].items()
  1557. if n in self.nodes
  1558. for v in seen.values()
  1559. )
  1560. # if two nodes updated the state at the same time, it's ambiguous
  1561. if last_seen_by_node:
  1562. if len(last_seen_by_node) == 1:
  1563. as_node = last_seen_by_node[0][1]
  1564. elif last_seen_by_node[-1][0] != last_seen_by_node[-2][0]:
  1565. as_node = last_seen_by_node[-1][1]
  1566. if as_node is None:
  1567. raise InvalidUpdateError("Ambiguous update, specify as_node")
  1568. if as_node not in self.nodes:
  1569. raise InvalidUpdateError(f"Node {as_node} does not exist")
  1570. valid_updates.append((as_node, values, task_id))
  1571. else:
  1572. for values, as_node, task_id in updates:
  1573. if as_node is None:
  1574. raise InvalidUpdateError(
  1575. "as_node is required when applying multiple updates"
  1576. )
  1577. if as_node not in self.nodes:
  1578. raise InvalidUpdateError(f"Node {as_node} does not exist")
  1579. valid_updates.append((as_node, values, task_id))
  1580. run_tasks: list[PregelTaskWrites] = []
  1581. run_task_ids: list[str] = []
  1582. for as_node, values, provided_task_id in valid_updates:
  1583. # create task to run all writers of the chosen node
  1584. writers = self.nodes[as_node].flat_writers
  1585. if not writers:
  1586. raise InvalidUpdateError(f"Node {as_node} has no writers")
  1587. writes: deque[tuple[str, Any]] = deque()
  1588. task = PregelTaskWrites((), as_node, writes, [INTERRUPT])
  1589. # get the task ids that were prepared for this node
  1590. # if a task id was provided in the StateUpdate, we use it
  1591. # otherwise, we use the next available task id
  1592. prepared_task_ids = node_to_task_ids.get(as_node, deque())
  1593. task_id = provided_task_id or (
  1594. prepared_task_ids.popleft()
  1595. if prepared_task_ids
  1596. else str(uuid5(UUID(checkpoint["id"]), INTERRUPT))
  1597. )
  1598. run_tasks.append(task)
  1599. run_task_ids.append(task_id)
  1600. run = RunnableSequence(*writers) if len(writers) > 1 else writers[0]
  1601. # execute task
  1602. run.invoke(
  1603. values,
  1604. patch_config(
  1605. config,
  1606. run_name=self.name + "UpdateState",
  1607. configurable={
  1608. # deque.extend is thread-safe
  1609. CONFIG_KEY_SEND: writes.extend,
  1610. CONFIG_KEY_TASK_ID: task_id,
  1611. CONFIG_KEY_READ: partial(
  1612. local_read,
  1613. _scratchpad(
  1614. None,
  1615. [],
  1616. task_id,
  1617. "",
  1618. None,
  1619. step,
  1620. step + 2,
  1621. ),
  1622. channels,
  1623. managed,
  1624. task,
  1625. ),
  1626. },
  1627. ),
  1628. )
  1629. # save task writes
  1630. for task_id, task in zip(run_task_ids, run_tasks):
  1631. # channel writes are saved to current checkpoint
  1632. channel_writes = [w for w in task.writes if w[0] != PUSH]
  1633. if saved and channel_writes:
  1634. checkpointer.put_writes(checkpoint_config, channel_writes, task_id)
  1635. # apply to checkpoint and save
  1636. apply_writes(
  1637. checkpoint,
  1638. channels,
  1639. run_tasks,
  1640. checkpointer.get_next_version,
  1641. self.trigger_to_nodes,
  1642. )
  1643. checkpoint = create_checkpoint(checkpoint, channels, step + 1)
  1644. next_config = checkpointer.put(
  1645. checkpoint_config,
  1646. checkpoint,
  1647. {
  1648. "source": "update",
  1649. "step": step + 1,
  1650. "parents": saved.metadata.get("parents", {}) if saved else {},
  1651. },
  1652. get_new_channel_versions(
  1653. checkpoint_previous_versions, checkpoint["channel_versions"]
  1654. ),
  1655. )
  1656. for task_id, task in zip(run_task_ids, run_tasks):
  1657. # save push writes
  1658. if push_writes := [w for w in task.writes if w[0] == PUSH]:
  1659. checkpointer.put_writes(next_config, push_writes, task_id)
  1660. return patch_checkpoint_map(next_config, saved.metadata if saved else None)
  1661. current_config = patch_configurable(
  1662. config, {CONFIG_KEY_THREAD_ID: str(config[CONF][CONFIG_KEY_THREAD_ID])}
  1663. )
  1664. for superstep in supersteps:
  1665. current_config = perform_superstep(current_config, superstep)
  1666. return current_config
async def abulk_update_state(
    self,
    config: RunnableConfig,
    supersteps: Sequence[Sequence[StateUpdate]],
) -> RunnableConfig:
    """Asynchronously apply updates to the graph state in bulk. Requires a checkpointer to be set.

    Args:
        config: The config to apply the updates to.
        supersteps: A list of supersteps, each including a list of updates to apply sequentially to a graph state.
            Each update is a tuple of the form `(values, as_node, task_id)` where `task_id` is optional.

    Raises:
        ValueError: If no checkpointer is set or no updates are provided.
        InvalidUpdateError: If an invalid update is provided.

    Returns:
        RunnableConfig: The updated config.
    """
    # a checkpointer passed down via config (e.g. by a parent graph) takes
    # precedence over the instance attribute
    checkpointer: BaseCheckpointSaver | None = ensure_config(config)[CONF].get(
        CONFIG_KEY_CHECKPOINTER, self.checkpointer
    )
    if not checkpointer:
        raise ValueError("No checkpointer set")
    if len(supersteps) == 0:
        raise ValueError("No supersteps provided")
    if any(len(u) == 0 for u in supersteps):
        raise ValueError("No updates provided")
    # delegate to subgraph
    if (
        checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")
    ) and CONFIG_KEY_CHECKPOINTER not in config[CONF]:
        # remove task_ids from checkpoint_ns
        recast = recast_checkpoint_ns(checkpoint_ns)
        # find the subgraph with the matching name
        async for _, pregel in self.aget_subgraphs(namespace=recast, recurse=True):
            # delegate to the first match, passing our checkpointer along so
            # the subgraph persists into the same saver
            return await pregel.abulk_update_state(
                patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}),
                supersteps,
            )
        else:
            # for-else: reached only when no subgraph matched the namespace
            raise ValueError(f"Subgraph {recast} not found")

    async def aperform_superstep(
        input_config: RunnableConfig, updates: Sequence[StateUpdate]
    ) -> RunnableConfig:
        """Apply one superstep's updates and persist a new checkpoint.

        Returns a config pointing at the checkpoint created for this superstep.
        """
        # get last checkpoint
        config = ensure_config(self.config, input_config)
        saved = await checkpointer.aget_tuple(config)
        if saved is not None:
            self._migrate_checkpoint(saved.checkpoint)
        checkpoint = (
            copy_checkpoint(saved.checkpoint) if saved else empty_checkpoint()
        )
        # snapshot channel versions before mutation, so we can compute the
        # version delta to persist at the end
        checkpoint_previous_versions = (
            saved.checkpoint["channel_versions"].copy() if saved else {}
        )
        step = saved.metadata.get("step", -1) if saved else -1
        # merge configurable fields with previous checkpoint config
        checkpoint_config = patch_configurable(
            config,
            {
                CONFIG_KEY_CHECKPOINT_NS: config[CONF].get(
                    CONFIG_KEY_CHECKPOINT_NS, ""
                )
            },
        )
        if saved:
            checkpoint_config = patch_configurable(config, saved.config[CONF])
        channels, managed = channels_from_checkpoint(
            self.channels,
            checkpoint,
        )
        # the first update determines which special branch (if any) applies
        values, as_node = updates[0][:2]
        # no values, just clear all tasks
        if values is None and as_node == END:
            if len(updates) > 1:
                raise InvalidUpdateError(
                    "Cannot apply multiple updates when clearing state"
                )
            if saved is not None:
                # tasks for this checkpoint
                next_tasks = prepare_next_tasks(
                    checkpoint,
                    saved.pending_writes or [],
                    self.nodes,
                    channels,
                    managed,
                    saved.config,
                    step + 1,
                    step + 3,
                    for_execution=True,
                    store=self.store,
                    checkpointer=checkpointer,
                    manager=None,
                )
                # apply null writes
                if null_writes := [
                    w[1:]
                    for w in saved.pending_writes or []
                    if w[0] == NULL_TASK_ID
                ]:
                    apply_writes(
                        checkpoint,
                        channels,
                        [PregelTaskWrites((), INPUT, null_writes, [])],
                        checkpointer.get_next_version,
                        self.trigger_to_nodes,
                    )
                # apply writes from tasks that already ran, skipping
                # error/interrupt markers and writes for unknown task ids
                for tid, k, v in saved.pending_writes or []:
                    if k in (ERROR, INTERRUPT):
                        continue
                    if tid not in next_tasks:
                        continue
                    next_tasks[tid].writes.append((k, v))
                # clear all current tasks
                apply_writes(
                    checkpoint,
                    channels,
                    next_tasks.values(),
                    checkpointer.get_next_version,
                    self.trigger_to_nodes,
                )
            # save checkpoint
            next_config = await checkpointer.aput(
                checkpoint_config,
                create_checkpoint(checkpoint, channels, step),
                {
                    "source": "update",
                    "step": step + 1,
                    "parents": saved.metadata.get("parents", {}) if saved else {},
                },
                get_new_channel_versions(
                    checkpoint_previous_versions, checkpoint["channel_versions"]
                ),
            )
            return patch_checkpoint_map(
                next_config, saved.metadata if saved else None
            )
        # act as an input
        if as_node == INPUT:
            if len(updates) > 1:
                raise InvalidUpdateError(
                    "Cannot apply multiple updates when updating as input"
                )
            if input_writes := deque(map_input(self.input_channels, values)):
                apply_writes(
                    checkpoint,
                    channels,
                    [PregelTaskWrites((), INPUT, input_writes, [])],
                    checkpointer.get_next_version,
                    self.trigger_to_nodes,
                )
                # apply input write to channels
                next_step = (
                    step + 1
                    if saved and saved.metadata.get("step") is not None
                    else -1
                )
                next_config = await checkpointer.aput(
                    checkpoint_config,
                    create_checkpoint(checkpoint, channels, next_step),
                    {
                        "source": "input",
                        "step": next_step,
                        "parents": saved.metadata.get("parents", {})
                        if saved
                        else {},
                    },
                    get_new_channel_versions(
                        checkpoint_previous_versions,
                        checkpoint["channel_versions"],
                    ),
                )
                # store the writes
                await checkpointer.aput_writes(
                    next_config,
                    input_writes,
                    str(uuid5(UUID(checkpoint["id"]), INPUT)),
                )
                return patch_checkpoint_map(
                    next_config, saved.metadata if saved else None
                )
            else:
                raise InvalidUpdateError(
                    f"Received no input writes for {self.input_channels}"
                )
        # no values, copy checkpoint
        if as_node == "__copy__":
            if len(updates) > 1:
                raise InvalidUpdateError(
                    "Cannot copy checkpoint with multiple updates"
                )
            if saved is None:
                raise InvalidUpdateError("Cannot copy a non-existent checkpoint")
            next_checkpoint = create_checkpoint(checkpoint, None, step)
            # copy checkpoint: fork off the saved checkpoint's parent
            next_config = await checkpointer.aput(
                saved.parent_config
                or patch_configurable(
                    saved.config, {CONFIG_KEY_CHECKPOINT_ID: None}
                ),
                next_checkpoint,
                {
                    "source": "fork",
                    "step": step + 1,
                    "parents": saved.metadata.get("parents", {}),
                },
                {},
            )
            # we want to both clone a checkpoint and update state in one go.
            # reuse the same task ID if possible.
            if isinstance(values, list) and len(values) > 0:
                # figure out the task IDs for the next update checkpoint
                next_tasks = prepare_next_tasks(
                    next_checkpoint,
                    saved.pending_writes or [],
                    self.nodes,
                    channels,
                    managed,
                    next_config,
                    step + 2,
                    step + 4,
                    for_execution=True,
                    store=self.store,
                    checkpointer=checkpointer,
                    manager=None,
                )
                # pair each user-supplied update with a prepared task id for
                # the same node, matching them up in order of appearance
                tasks_group_by = defaultdict(list)
                user_group_by: dict[str, list[StateUpdate]] = defaultdict(list)
                for task in next_tasks.values():
                    tasks_group_by[task.name].append(task.id)
                for item in values:
                    if not isinstance(item, Sequence):
                        raise InvalidUpdateError(
                            f"Invalid update item: {item} when copying checkpoint"
                        )
                    values, as_node = item[:2]
                    user_group = user_group_by[as_node]
                    tasks_group = tasks_group_by[as_node]
                    target_idx = len(user_group)
                    task_id = (
                        tasks_group[target_idx]
                        if target_idx < len(tasks_group)
                        else None
                    )
                    user_group_by[as_node].append(
                        StateUpdate(values=values, as_node=as_node, task_id=task_id)
                    )
                # recurse to apply the updates on top of the fork
                return await aperform_superstep(
                    patch_checkpoint_map(next_config, saved.metadata),
                    [item for lst in user_group_by.values() for item in lst],
                )
            return patch_checkpoint_map(
                next_config, saved.metadata if saved else None
            )
        # task ids can be provided in the StateUpdate, but if not,
        # we use the task id generated by prepare_next_tasks
        node_to_task_ids: dict[str, deque[str]] = defaultdict(deque)
        if saved is not None and saved.pending_writes is not None:
            # we call prepare_next_tasks to discover the task IDs that
            # would have been generated, so we can reuse them and
            # properly populate task.result in state history
            next_tasks = prepare_next_tasks(
                checkpoint,
                saved.pending_writes,
                self.nodes,
                channels,
                managed,
                saved.config,
                step + 1,
                step + 3,
                for_execution=True,
                store=self.store,
                checkpointer=checkpointer,
                manager=None,
            )
            # collect task ids to reuse so we can properly attach task results
            for t in next_tasks.values():
                node_to_task_ids[t.name].append(t.id)
        valid_updates: list[tuple[str, dict[str, Any] | None, str | None]] = []
        if len(updates) == 1:
            values, as_node, task_id = updates[0]
            # find last node that updated the state, if not provided
            if as_node is None and len(self.nodes) == 1:
                as_node = tuple(self.nodes)[0]
            elif as_node is None and not saved:
                # NOTE(review): the sync bulk_update_state guards this branch
                # with an emptiness check on checkpoint["versions_seen"]
                # instead of `not saved` -- confirm the divergence is intended
                if (
                    isinstance(self.input_channels, str)
                    and self.input_channels in self.nodes
                ):
                    as_node = self.input_channels
            elif as_node is None:
                last_seen_by_node = sorted(
                    (v, n)
                    for n, seen in checkpoint["versions_seen"].items()
                    if n in self.nodes
                    for v in seen.values()
                )
                # if two nodes updated the state at the same time, it's ambiguous
                if last_seen_by_node:
                    if len(last_seen_by_node) == 1:
                        as_node = last_seen_by_node[0][1]
                    elif last_seen_by_node[-1][0] != last_seen_by_node[-2][0]:
                        as_node = last_seen_by_node[-1][1]
            if as_node is None:
                raise InvalidUpdateError("Ambiguous update, specify as_node")
            if as_node not in self.nodes:
                raise InvalidUpdateError(f"Node {as_node} does not exist")
            valid_updates.append((as_node, values, task_id))
        else:
            # multiple updates in one superstep: each must name its node
            for values, as_node, task_id in updates:
                if as_node is None:
                    raise InvalidUpdateError(
                        "as_node is required when applying multiple updates"
                    )
                if as_node not in self.nodes:
                    raise InvalidUpdateError(f"Node {as_node} does not exist")
                valid_updates.append((as_node, values, task_id))
        run_tasks: list[PregelTaskWrites] = []
        run_task_ids: list[str] = []
        for as_node, values, provided_task_id in valid_updates:
            # create task to run all writers of the chosen node
            writers = self.nodes[as_node].flat_writers
            if not writers:
                raise InvalidUpdateError(f"Node {as_node} has no writers")
            writes: deque[tuple[str, Any]] = deque()
            task = PregelTaskWrites((), as_node, writes, [INTERRUPT])
            # get the task ids that were prepared for this node
            # if a task id was provided in the StateUpdate, we use it
            # otherwise, we use the next available task id
            prepared_task_ids = node_to_task_ids.get(as_node, deque())
            task_id = provided_task_id or (
                prepared_task_ids.popleft()
                if prepared_task_ids
                else str(uuid5(UUID(checkpoint["id"]), INTERRUPT))
            )
            run_tasks.append(task)
            run_task_ids.append(task_id)
            run = RunnableSequence(*writers) if len(writers) > 1 else writers[0]
            # execute task
            await run.ainvoke(
                values,
                patch_config(
                    config,
                    run_name=self.name + "UpdateState",
                    configurable={
                        # deque.extend is thread-safe
                        CONFIG_KEY_SEND: writes.extend,
                        CONFIG_KEY_TASK_ID: task_id,
                        CONFIG_KEY_READ: partial(
                            local_read,
                            _scratchpad(
                                None,
                                [],
                                task_id,
                                "",
                                None,
                                step,
                                step + 2,
                            ),
                            channels,
                            managed,
                            task,
                        ),
                    },
                ),
            )
        # save task writes
        for task_id, task in zip(run_task_ids, run_tasks):
            # channel writes are saved to current checkpoint
            channel_writes = [w for w in task.writes if w[0] != PUSH]
            if saved and channel_writes:
                await checkpointer.aput_writes(
                    checkpoint_config, channel_writes, task_id
                )
        # apply to checkpoint and save
        apply_writes(
            checkpoint,
            channels,
            run_tasks,
            checkpointer.get_next_version,
            self.trigger_to_nodes,
        )
        checkpoint = create_checkpoint(checkpoint, channels, step + 1)
        # save checkpoint, after applying writes
        next_config = await checkpointer.aput(
            checkpoint_config,
            checkpoint,
            {
                "source": "update",
                "step": step + 1,
                "parents": saved.metadata.get("parents", {}) if saved else {},
            },
            get_new_channel_versions(
                checkpoint_previous_versions, checkpoint["channel_versions"]
            ),
        )
        for task_id, task in zip(run_task_ids, run_tasks):
            # save push writes (these belong to the newly created checkpoint)
            if push_writes := [w for w in task.writes if w[0] == PUSH]:
                await checkpointer.aput_writes(next_config, push_writes, task_id)
        return patch_checkpoint_map(next_config, saved.metadata if saved else None)

    # normalize thread_id to a string, then apply each superstep in order,
    # each one building on the checkpoint created by the previous
    current_config = patch_configurable(
        config, {CONFIG_KEY_THREAD_ID: str(config[CONF][CONFIG_KEY_THREAD_ID])}
    )
    for superstep in supersteps:
        current_config = await aperform_superstep(current_config, superstep)
    return current_config
  2073. def update_state(
  2074. self,
  2075. config: RunnableConfig,
  2076. values: dict[str, Any] | Any | None,
  2077. as_node: str | None = None,
  2078. task_id: str | None = None,
  2079. ) -> RunnableConfig:
  2080. """Update the state of the graph with the given values, as if they came from
  2081. node `as_node`. If `as_node` is not provided, it will be set to the last node
  2082. that updated the state, if not ambiguous.
  2083. """
  2084. return self.bulk_update_state(config, [[StateUpdate(values, as_node, task_id)]])
  2085. async def aupdate_state(
  2086. self,
  2087. config: RunnableConfig,
  2088. values: dict[str, Any] | Any,
  2089. as_node: str | None = None,
  2090. task_id: str | None = None,
  2091. ) -> RunnableConfig:
  2092. """Asynchronously update the state of the graph with the given values, as if they came from
  2093. node `as_node`. If `as_node` is not provided, it will be set to the last node
  2094. that updated the state, if not ambiguous.
  2095. """
  2096. return await self.abulk_update_state(
  2097. config, [[StateUpdate(values, as_node, task_id)]]
  2098. )
  2099. def _defaults(
  2100. self,
  2101. config: RunnableConfig,
  2102. *,
  2103. stream_mode: StreamMode | Sequence[StreamMode],
  2104. print_mode: StreamMode | Sequence[StreamMode],
  2105. output_keys: str | Sequence[str] | None,
  2106. interrupt_before: All | Sequence[str] | None,
  2107. interrupt_after: All | Sequence[str] | None,
  2108. durability: Durability | None = None,
  2109. ) -> tuple[
  2110. set[StreamMode],
  2111. str | Sequence[str],
  2112. All | Sequence[str],
  2113. All | Sequence[str],
  2114. BaseCheckpointSaver | None,
  2115. BaseStore | None,
  2116. BaseCache | None,
  2117. Durability,
  2118. ]:
  2119. if config["recursion_limit"] < 1:
  2120. raise ValueError("recursion_limit must be at least 1")
  2121. if output_keys is None:
  2122. output_keys = self.stream_channels_asis
  2123. else:
  2124. validate_keys(output_keys, self.channels)
  2125. interrupt_before = interrupt_before or self.interrupt_before_nodes
  2126. interrupt_after = interrupt_after or self.interrupt_after_nodes
  2127. if isinstance(stream_mode, str):
  2128. stream_modes = {stream_mode}
  2129. else:
  2130. stream_modes = set(stream_mode)
  2131. if isinstance(print_mode, str):
  2132. stream_modes.add(print_mode)
  2133. else:
  2134. stream_modes.update(print_mode)
  2135. if self.checkpointer is False:
  2136. checkpointer: BaseCheckpointSaver | None = None
  2137. elif CONFIG_KEY_CHECKPOINTER in config.get(CONF, {}):
  2138. checkpointer = config[CONF][CONFIG_KEY_CHECKPOINTER]
  2139. elif self.checkpointer is True:
  2140. raise RuntimeError("checkpointer=True cannot be used for root graphs.")
  2141. else:
  2142. checkpointer = self.checkpointer
  2143. if checkpointer and not config.get(CONF):
  2144. raise ValueError(
  2145. "Checkpointer requires one or more of the following 'configurable' "
  2146. "keys: thread_id, checkpoint_ns, checkpoint_id"
  2147. )
  2148. if CONFIG_KEY_RUNTIME in config.get(CONF, {}):
  2149. store: BaseStore | None = config[CONF][CONFIG_KEY_RUNTIME].store
  2150. else:
  2151. store = self.store
  2152. if CONFIG_KEY_CACHE in config.get(CONF, {}):
  2153. cache: BaseCache | None = config[CONF][CONFIG_KEY_CACHE]
  2154. else:
  2155. cache = self.cache
  2156. if durability is None:
  2157. durability = config.get(CONF, {}).get(CONFIG_KEY_DURABILITY, "async")
  2158. return (
  2159. stream_modes,
  2160. output_keys,
  2161. interrupt_before,
  2162. interrupt_after,
  2163. checkpointer,
  2164. store,
  2165. cache,
  2166. durability,
  2167. )
  2168. def stream(
  2169. self,
  2170. input: InputT | Command | None,
  2171. config: RunnableConfig | None = None,
  2172. *,
  2173. context: ContextT | None = None,
  2174. stream_mode: StreamMode | Sequence[StreamMode] | None = None,
  2175. print_mode: StreamMode | Sequence[StreamMode] = (),
  2176. output_keys: str | Sequence[str] | None = None,
  2177. interrupt_before: All | Sequence[str] | None = None,
  2178. interrupt_after: All | Sequence[str] | None = None,
  2179. durability: Durability | None = None,
  2180. subgraphs: bool = False,
  2181. debug: bool | None = None,
  2182. **kwargs: Unpack[DeprecatedKwargs],
  2183. ) -> Iterator[dict[str, Any] | Any]:
  2184. """Stream graph steps for a single input.
  2185. Args:
  2186. input: The input to the graph.
  2187. config: The configuration to use for the run.
  2188. context: The static context to use for the run.
  2189. !!! version-added "Added in version 0.6.0"
  2190. stream_mode: The mode to stream output, defaults to `self.stream_mode`.
  2191. Options are:
  2192. - `"values"`: Emit all values in the state after each step, including interrupts.
  2193. When used with functional API, values are emitted once at the end of the workflow.
  2194. - `"updates"`: Emit only the node or task names and updates returned by the nodes or tasks after each step.
  2195. If multiple updates are made in the same step (e.g. multiple nodes are run) then those updates are emitted separately.
  2196. - `"custom"`: Emit custom data from inside nodes or tasks using `StreamWriter`.
  2197. - `"messages"`: Emit LLM messages token-by-token together with metadata for any LLM invocations inside nodes or tasks.
  2198. - Will be emitted as 2-tuples `(LLM token, metadata)`.
  2199. - `"checkpoints"`: Emit an event when a checkpoint is created, in the same format as returned by `get_state()`.
  2200. - `"tasks"`: Emit events when tasks start and finish, including their results and errors.
  2201. - `"debug"`: Emit debug events with as much information as possible for each step.
  2202. You can pass a list as the `stream_mode` parameter to stream multiple modes at once.
  2203. The streamed outputs will be tuples of `(mode, data)`.
  2204. See [LangGraph streaming guide](https://docs.langchain.com/oss/python/langgraph/streaming) for more details.
  2205. print_mode: Accepts the same values as `stream_mode`, but only prints the output to the console, for debugging purposes.
  2206. Does not affect the output of the graph in any way.
  2207. output_keys: The keys to stream, defaults to all non-context channels.
  2208. interrupt_before: Nodes to interrupt before, defaults to all nodes in the graph.
  2209. interrupt_after: Nodes to interrupt after, defaults to all nodes in the graph.
  2210. durability: The durability mode for the graph execution, defaults to `"async"`.
  2211. Options are:
  2212. - `"sync"`: Changes are persisted synchronously before the next step starts.
  2213. - `"async"`: Changes are persisted asynchronously while the next step executes.
  2214. - `"exit"`: Changes are persisted only when the graph exits.
  2215. subgraphs: Whether to stream events from inside subgraphs, defaults to `False`.
  2216. If `True`, the events will be emitted as tuples `(namespace, data)`,
  2217. or `(namespace, mode, data)` if `stream_mode` is a list,
  2218. where `namespace` is a tuple with the path to the node where a subgraph is invoked,
  2219. e.g. `("parent_node:<task_id>", "child_node:<task_id>")`.
  2220. See [LangGraph streaming guide](https://docs.langchain.com/oss/python/langgraph/streaming) for more details.
  2221. Yields:
  2222. The output of each step in the graph. The output shape depends on the `stream_mode`.
  2223. """
  2224. if (checkpoint_during := kwargs.get("checkpoint_during")) is not None:
  2225. warnings.warn(
  2226. "`checkpoint_during` is deprecated and will be removed. Please use `durability` instead.",
  2227. category=LangGraphDeprecatedSinceV10,
  2228. stacklevel=2,
  2229. )
  2230. if durability is not None:
  2231. raise ValueError(
  2232. "Cannot use both `checkpoint_during` and `durability` parameters. Please use `durability` instead."
  2233. )
  2234. durability = "async" if checkpoint_during else "exit"
  2235. if stream_mode is None:
  2236. # if being called as a node in another graph, default to values mode
  2237. # but don't overwrite stream_mode arg if provided
  2238. stream_mode = (
  2239. "values"
  2240. if config is not None and CONFIG_KEY_TASK_ID in config.get(CONF, {})
  2241. else self.stream_mode
  2242. )
  2243. if debug or self.debug:
  2244. print_mode = ["updates", "values"]
  2245. stream = SyncQueue()
  2246. config = ensure_config(self.config, config)
  2247. callback_manager = get_callback_manager_for_config(config)
  2248. run_manager = callback_manager.on_chain_start(
  2249. None,
  2250. input,
  2251. name=config.get("run_name", self.get_name()),
  2252. run_id=config.get("run_id"),
  2253. )
  2254. try:
  2255. # assign defaults
  2256. (
  2257. stream_modes,
  2258. output_keys,
  2259. interrupt_before_,
  2260. interrupt_after_,
  2261. checkpointer,
  2262. store,
  2263. cache,
  2264. durability_,
  2265. ) = self._defaults(
  2266. config,
  2267. stream_mode=stream_mode,
  2268. print_mode=print_mode,
  2269. output_keys=output_keys,
  2270. interrupt_before=interrupt_before,
  2271. interrupt_after=interrupt_after,
  2272. durability=durability,
  2273. )
  2274. if checkpointer is None and durability is not None:
  2275. warnings.warn(
  2276. "`durability` has no effect when no checkpointer is present.",
  2277. )
  2278. # set up subgraph checkpointing
  2279. if self.checkpointer is True:
  2280. ns = cast(str, config[CONF][CONFIG_KEY_CHECKPOINT_NS])
  2281. config[CONF][CONFIG_KEY_CHECKPOINT_NS] = recast_checkpoint_ns(ns)
  2282. # set up messages stream mode
  2283. if "messages" in stream_modes:
  2284. ns_ = cast(str | None, config[CONF].get(CONFIG_KEY_CHECKPOINT_NS))
  2285. run_manager.inheritable_handlers.append(
  2286. StreamMessagesHandler(
  2287. stream.put,
  2288. subgraphs,
  2289. parent_ns=tuple(ns_.split(NS_SEP)) if ns_ else None,
  2290. )
  2291. )
  2292. # set up custom stream mode
  2293. if "custom" in stream_modes:
  2294. def stream_writer(c: Any) -> None:
  2295. stream.put(
  2296. (
  2297. tuple(
  2298. get_config()[CONF][CONFIG_KEY_CHECKPOINT_NS].split(
  2299. NS_SEP
  2300. )[:-1]
  2301. ),
  2302. "custom",
  2303. c,
  2304. )
  2305. )
  2306. elif CONFIG_KEY_STREAM in config[CONF]:
  2307. stream_writer = config[CONF][CONFIG_KEY_RUNTIME].stream_writer
  2308. else:
  2309. def stream_writer(c: Any) -> None:
  2310. pass
  2311. # set durability mode for subgraphs
  2312. if durability is not None:
  2313. config[CONF][CONFIG_KEY_DURABILITY] = durability_
  2314. runtime = Runtime(
  2315. context=_coerce_context(self.context_schema, context),
  2316. store=store,
  2317. stream_writer=stream_writer,
  2318. previous=None,
  2319. )
  2320. parent_runtime = config[CONF].get(CONFIG_KEY_RUNTIME, DEFAULT_RUNTIME)
  2321. runtime = parent_runtime.merge(runtime)
  2322. config[CONF][CONFIG_KEY_RUNTIME] = runtime
  2323. with SyncPregelLoop(
  2324. input,
  2325. stream=StreamProtocol(stream.put, stream_modes),
  2326. config=config,
  2327. store=store,
  2328. cache=cache,
  2329. checkpointer=checkpointer,
  2330. nodes=self.nodes,
  2331. specs=self.channels,
  2332. output_keys=output_keys,
  2333. input_keys=self.input_channels,
  2334. stream_keys=self.stream_channels_asis,
  2335. interrupt_before=interrupt_before_,
  2336. interrupt_after=interrupt_after_,
  2337. manager=run_manager,
  2338. durability=durability_,
  2339. trigger_to_nodes=self.trigger_to_nodes,
  2340. migrate_checkpoint=self._migrate_checkpoint,
  2341. retry_policy=self.retry_policy,
  2342. cache_policy=self.cache_policy,
  2343. ) as loop:
  2344. # create runner
  2345. runner = PregelRunner(
  2346. submit=config[CONF].get(
  2347. CONFIG_KEY_RUNNER_SUBMIT, weakref.WeakMethod(loop.submit)
  2348. ),
  2349. put_writes=weakref.WeakMethod(loop.put_writes),
  2350. node_finished=config[CONF].get(CONFIG_KEY_NODE_FINISHED),
  2351. )
  2352. # enable subgraph streaming
  2353. if subgraphs:
  2354. loop.config[CONF][CONFIG_KEY_STREAM] = loop.stream
  2355. # enable concurrent streaming
  2356. get_waiter: Callable[[], concurrent.futures.Future[None]] | None = None
  2357. if (
  2358. self.stream_eager
  2359. or subgraphs
  2360. or "messages" in stream_modes
  2361. or "custom" in stream_modes
  2362. ):
  2363. # we are careful to have a single waiter live at any one time
  2364. # because on exit we increment semaphore count by exactly 1
  2365. waiter: concurrent.futures.Future | None = None
  2366. # because sync futures cannot be cancelled, we instead
  2367. # release the stream semaphore on exit, which will cause
  2368. # a pending waiter to return immediately
  2369. loop.stack.callback(stream._count.release)
  2370. def get_waiter() -> concurrent.futures.Future[None]:
  2371. nonlocal waiter
  2372. if waiter is None or waiter.done():
  2373. waiter = loop.submit(stream.wait)
  2374. return waiter
  2375. else:
  2376. return waiter
  2377. # Similarly to Bulk Synchronous Parallel / Pregel model
  2378. # computation proceeds in steps, while there are channel updates.
  2379. # Channel updates from step N are only visible in step N+1
  2380. # channels are guaranteed to be immutable for the duration of the step,
  2381. # with channel updates applied only at the transition between steps.
  2382. while loop.tick():
  2383. for task in loop.match_cached_writes():
  2384. loop.output_writes(task.id, task.writes, cached=True)
  2385. for _ in runner.tick(
  2386. [t for t in loop.tasks.values() if not t.writes],
  2387. timeout=self.step_timeout,
  2388. get_waiter=get_waiter,
  2389. schedule_task=loop.accept_push,
  2390. ):
  2391. # emit output
  2392. yield from _output(
  2393. stream_mode, print_mode, subgraphs, stream.get, queue.Empty
  2394. )
  2395. loop.after_tick()
  2396. # wait for checkpoint
  2397. if durability_ == "sync":
  2398. loop._put_checkpoint_fut.result()
  2399. # emit output
  2400. yield from _output(
  2401. stream_mode, print_mode, subgraphs, stream.get, queue.Empty
  2402. )
  2403. # handle exit
  2404. if loop.status == "out_of_steps":
  2405. msg = create_error_message(
  2406. message=(
  2407. f"Recursion limit of {config['recursion_limit']} reached "
  2408. "without hitting a stop condition. You can increase the "
  2409. "limit by setting the `recursion_limit` config key."
  2410. ),
  2411. error_code=ErrorCode.GRAPH_RECURSION_LIMIT,
  2412. )
  2413. raise GraphRecursionError(msg)
  2414. # set final channel values as run output
  2415. run_manager.on_chain_end(loop.output)
  2416. except BaseException as e:
  2417. run_manager.on_chain_error(e)
  2418. raise
  2419. async def astream(
  2420. self,
  2421. input: InputT | Command | None,
  2422. config: RunnableConfig | None = None,
  2423. *,
  2424. context: ContextT | None = None,
  2425. stream_mode: StreamMode | Sequence[StreamMode] | None = None,
  2426. print_mode: StreamMode | Sequence[StreamMode] = (),
  2427. output_keys: str | Sequence[str] | None = None,
  2428. interrupt_before: All | Sequence[str] | None = None,
  2429. interrupt_after: All | Sequence[str] | None = None,
  2430. durability: Durability | None = None,
  2431. subgraphs: bool = False,
  2432. debug: bool | None = None,
  2433. **kwargs: Unpack[DeprecatedKwargs],
  2434. ) -> AsyncIterator[dict[str, Any] | Any]:
  2435. """Asynchronously stream graph steps for a single input.
  2436. Args:
  2437. input: The input to the graph.
  2438. config: The configuration to use for the run.
  2439. context: The static context to use for the run.
  2440. !!! version-added "Added in version 0.6.0"
  2441. stream_mode: The mode to stream output, defaults to `self.stream_mode`.
  2442. Options are:
  2443. - `"values"`: Emit all values in the state after each step, including interrupts.
  2444. When used with functional API, values are emitted once at the end of the workflow.
  2445. - `"updates"`: Emit only the node or task names and updates returned by the nodes or tasks after each step.
  2446. If multiple updates are made in the same step (e.g. multiple nodes are run) then those updates are emitted separately.
  2447. - `"custom"`: Emit custom data from inside nodes or tasks using `StreamWriter`.
  2448. - `"messages"`: Emit LLM messages token-by-token together with metadata for any LLM invocations inside nodes or tasks.
  2449. - Will be emitted as 2-tuples `(LLM token, metadata)`.
  2450. - `"checkpoints"`: Emit an event when a checkpoint is created, in the same format as returned by `get_state()`.
  2451. - `"tasks"`: Emit events when tasks start and finish, including their results and errors.
  2452. - `"debug"`: Emit debug events with as much information as possible for each step.
  2453. You can pass a list as the `stream_mode` parameter to stream multiple modes at once.
  2454. The streamed outputs will be tuples of `(mode, data)`.
  2455. See [LangGraph streaming guide](https://docs.langchain.com/oss/python/langgraph/streaming) for more details.
  2456. print_mode: Accepts the same values as `stream_mode`, but only prints the output to the console, for debugging purposes.
  2457. Does not affect the output of the graph in any way.
  2458. output_keys: The keys to stream, defaults to all non-context channels.
  2459. interrupt_before: Nodes to interrupt before, defaults to all nodes in the graph.
  2460. interrupt_after: Nodes to interrupt after, defaults to all nodes in the graph.
  2461. durability: The durability mode for the graph execution, defaults to `"async"`.
  2462. Options are:
  2463. - `"sync"`: Changes are persisted synchronously before the next step starts.
  2464. - `"async"`: Changes are persisted asynchronously while the next step executes.
  2465. - `"exit"`: Changes are persisted only when the graph exits.
  2466. subgraphs: Whether to stream events from inside subgraphs, defaults to `False`.
  2467. If `True`, the events will be emitted as tuples `(namespace, data)`,
  2468. or `(namespace, mode, data)` if `stream_mode` is a list,
  2469. where `namespace` is a tuple with the path to the node where a subgraph is invoked,
  2470. e.g. `("parent_node:<task_id>", "child_node:<task_id>")`.
  2471. See [LangGraph streaming guide](https://docs.langchain.com/oss/python/langgraph/streaming) for more details.
  2472. Yields:
  2473. The output of each step in the graph. The output shape depends on the `stream_mode`.
  2474. """
  2475. if (checkpoint_during := kwargs.get("checkpoint_during")) is not None:
  2476. warnings.warn(
  2477. "`checkpoint_during` is deprecated and will be removed. Please use `durability` instead.",
  2478. category=LangGraphDeprecatedSinceV10,
  2479. stacklevel=2,
  2480. )
  2481. if durability is not None:
  2482. raise ValueError(
  2483. "Cannot use both `checkpoint_during` and `durability` parameters. Please use `durability` instead."
  2484. )
  2485. durability = "async" if checkpoint_during else "exit"
  2486. if stream_mode is None:
  2487. # if being called as a node in another graph, default to values mode
  2488. # but don't overwrite stream_mode arg if provided
  2489. stream_mode = (
  2490. "values"
  2491. if config is not None and CONFIG_KEY_TASK_ID in config.get(CONF, {})
  2492. else self.stream_mode
  2493. )
  2494. if debug or self.debug:
  2495. print_mode = ["updates", "values"]
  2496. stream = AsyncQueue()
  2497. aioloop = asyncio.get_running_loop()
  2498. stream_put = cast(
  2499. Callable[[StreamChunk], None],
  2500. partial(aioloop.call_soon_threadsafe, stream.put_nowait),
  2501. )
  2502. config = ensure_config(self.config, config)
  2503. callback_manager = get_async_callback_manager_for_config(config)
  2504. run_manager = await callback_manager.on_chain_start(
  2505. None,
  2506. input,
  2507. name=config.get("run_name", self.get_name()),
  2508. run_id=config.get("run_id"),
  2509. )
  2510. # if running from astream_log() run each proc with streaming
  2511. do_stream = (
  2512. next(
  2513. (
  2514. True
  2515. for h in run_manager.handlers
  2516. if isinstance(h, _StreamingCallbackHandler)
  2517. and not isinstance(h, StreamMessagesHandler)
  2518. ),
  2519. False,
  2520. )
  2521. if _StreamingCallbackHandler is not None
  2522. else False
  2523. )
  2524. try:
  2525. # assign defaults
  2526. (
  2527. stream_modes,
  2528. output_keys,
  2529. interrupt_before_,
  2530. interrupt_after_,
  2531. checkpointer,
  2532. store,
  2533. cache,
  2534. durability_,
  2535. ) = self._defaults(
  2536. config,
  2537. stream_mode=stream_mode,
  2538. print_mode=print_mode,
  2539. output_keys=output_keys,
  2540. interrupt_before=interrupt_before,
  2541. interrupt_after=interrupt_after,
  2542. durability=durability,
  2543. )
  2544. if checkpointer is None and durability is not None:
  2545. warnings.warn(
  2546. "`durability` has no effect when no checkpointer is present.",
  2547. )
  2548. # set up subgraph checkpointing
  2549. if self.checkpointer is True:
  2550. ns = cast(str, config[CONF][CONFIG_KEY_CHECKPOINT_NS])
  2551. config[CONF][CONFIG_KEY_CHECKPOINT_NS] = recast_checkpoint_ns(ns)
  2552. # set up messages stream mode
  2553. if "messages" in stream_modes:
  2554. # namespace can be None in a root level graph?
  2555. ns_ = cast(str | None, config[CONF].get(CONFIG_KEY_CHECKPOINT_NS))
  2556. run_manager.inheritable_handlers.append(
  2557. StreamMessagesHandler(
  2558. stream_put,
  2559. subgraphs,
  2560. parent_ns=tuple(ns_.split(NS_SEP)) if ns_ else None,
  2561. )
  2562. )
  2563. # set up custom stream mode
  2564. def stream_writer(c: Any) -> None:
  2565. aioloop.call_soon_threadsafe(
  2566. stream.put_nowait,
  2567. (
  2568. tuple(
  2569. get_config()[CONF][CONFIG_KEY_CHECKPOINT_NS].split(NS_SEP)[
  2570. :-1
  2571. ]
  2572. ),
  2573. "custom",
  2574. c,
  2575. ),
  2576. )
  2577. if "custom" in stream_modes:
  2578. def stream_writer(c: Any) -> None:
  2579. aioloop.call_soon_threadsafe(
  2580. stream.put_nowait,
  2581. (
  2582. tuple(
  2583. get_config()[CONF][CONFIG_KEY_CHECKPOINT_NS].split(
  2584. NS_SEP
  2585. )[:-1]
  2586. ),
  2587. "custom",
  2588. c,
  2589. ),
  2590. )
  2591. elif CONFIG_KEY_STREAM in config[CONF]:
  2592. stream_writer = config[CONF][CONFIG_KEY_RUNTIME].stream_writer
  2593. else:
  2594. def stream_writer(c: Any) -> None:
  2595. pass
  2596. # set durability mode for subgraphs
  2597. if durability is not None:
  2598. config[CONF][CONFIG_KEY_DURABILITY] = durability_
  2599. runtime = Runtime(
  2600. context=_coerce_context(self.context_schema, context),
  2601. store=store,
  2602. stream_writer=stream_writer,
  2603. previous=None,
  2604. )
  2605. parent_runtime = config[CONF].get(CONFIG_KEY_RUNTIME, DEFAULT_RUNTIME)
  2606. runtime = parent_runtime.merge(runtime)
  2607. config[CONF][CONFIG_KEY_RUNTIME] = runtime
  2608. async with AsyncPregelLoop(
  2609. input,
  2610. stream=StreamProtocol(stream.put_nowait, stream_modes),
  2611. config=config,
  2612. store=store,
  2613. cache=cache,
  2614. checkpointer=checkpointer,
  2615. nodes=self.nodes,
  2616. specs=self.channels,
  2617. output_keys=output_keys,
  2618. input_keys=self.input_channels,
  2619. stream_keys=self.stream_channels_asis,
  2620. interrupt_before=interrupt_before_,
  2621. interrupt_after=interrupt_after_,
  2622. manager=run_manager,
  2623. durability=durability_,
  2624. trigger_to_nodes=self.trigger_to_nodes,
  2625. migrate_checkpoint=self._migrate_checkpoint,
  2626. retry_policy=self.retry_policy,
  2627. cache_policy=self.cache_policy,
  2628. ) as loop:
  2629. # create runner
  2630. runner = PregelRunner(
  2631. submit=config[CONF].get(
  2632. CONFIG_KEY_RUNNER_SUBMIT, weakref.WeakMethod(loop.submit)
  2633. ),
  2634. put_writes=weakref.WeakMethod(loop.put_writes),
  2635. use_astream=do_stream,
  2636. node_finished=config[CONF].get(CONFIG_KEY_NODE_FINISHED),
  2637. )
  2638. # enable subgraph streaming
  2639. if subgraphs:
  2640. loop.config[CONF][CONFIG_KEY_STREAM] = StreamProtocol(
  2641. stream_put, stream_modes
  2642. )
  2643. # enable concurrent streaming
  2644. get_waiter: Callable[[], asyncio.Task[None]] | None = None
  2645. _cleanup_waiter: Callable[[], Awaitable[None]] | None = None
  2646. if (
  2647. self.stream_eager
  2648. or subgraphs
  2649. or "messages" in stream_modes
  2650. or "custom" in stream_modes
  2651. ):
  2652. # Keep a single waiter task alive; ensure cleanup on exit.
  2653. waiter: asyncio.Task[None] | None = None
  2654. def get_waiter() -> asyncio.Task[None]:
  2655. nonlocal waiter
  2656. if waiter is None or waiter.done():
  2657. waiter = aioloop.create_task(stream.wait())
  2658. def _clear(t: asyncio.Task[None]) -> None:
  2659. nonlocal waiter
  2660. if waiter is t:
  2661. waiter = None
  2662. waiter.add_done_callback(_clear)
  2663. return waiter
  2664. async def _cleanup_waiter() -> None:
  2665. """Wake pending waiter and/or cancel+await to avoid pending tasks."""
  2666. nonlocal waiter
  2667. # Try to wake via semaphore like SyncPregelLoop
  2668. with contextlib.suppress(Exception):
  2669. if hasattr(stream, "_count"):
  2670. stream._count.release()
  2671. t = waiter
  2672. waiter = None
  2673. if t is not None and not t.done():
  2674. t.cancel()
  2675. with contextlib.suppress(asyncio.CancelledError):
  2676. await t
  2677. # Similarly to Bulk Synchronous Parallel / Pregel model
  2678. # computation proceeds in steps, while there are channel updates
  2679. # channel updates from step N are only visible in step N+1
  2680. # channels are guaranteed to be immutable for the duration of the step,
  2681. # with channel updates applied only at the transition between steps
  2682. try:
  2683. while loop.tick():
  2684. for task in await loop.amatch_cached_writes():
  2685. loop.output_writes(task.id, task.writes, cached=True)
  2686. async for _ in runner.atick(
  2687. [t for t in loop.tasks.values() if not t.writes],
  2688. timeout=self.step_timeout,
  2689. get_waiter=get_waiter,
  2690. schedule_task=loop.aaccept_push,
  2691. ):
  2692. # emit output
  2693. for o in _output(
  2694. stream_mode,
  2695. print_mode,
  2696. subgraphs,
  2697. stream.get_nowait,
  2698. asyncio.QueueEmpty,
  2699. ):
  2700. yield o
  2701. loop.after_tick()
  2702. # wait for checkpoint
  2703. if durability_ == "sync":
  2704. await cast(asyncio.Future, loop._put_checkpoint_fut)
  2705. finally:
  2706. # ensure waiter doesn't remain pending on cancel/shutdown
  2707. if _cleanup_waiter is not None:
  2708. await _cleanup_waiter()
  2709. # emit output
  2710. for o in _output(
  2711. stream_mode,
  2712. print_mode,
  2713. subgraphs,
  2714. stream.get_nowait,
  2715. asyncio.QueueEmpty,
  2716. ):
  2717. yield o
  2718. # handle exit
  2719. if loop.status == "out_of_steps":
  2720. msg = create_error_message(
  2721. message=(
  2722. f"Recursion limit of {config['recursion_limit']} reached "
  2723. "without hitting a stop condition. You can increase the "
  2724. "limit by setting the `recursion_limit` config key."
  2725. ),
  2726. error_code=ErrorCode.GRAPH_RECURSION_LIMIT,
  2727. )
  2728. raise GraphRecursionError(msg)
  2729. # set final channel values as run output
  2730. await run_manager.on_chain_end(loop.output)
  2731. except BaseException as e:
  2732. await asyncio.shield(run_manager.on_chain_error(e))
  2733. raise
  2734. def invoke(
  2735. self,
  2736. input: InputT | Command | None,
  2737. config: RunnableConfig | None = None,
  2738. *,
  2739. context: ContextT | None = None,
  2740. stream_mode: StreamMode = "values",
  2741. print_mode: StreamMode | Sequence[StreamMode] = (),
  2742. output_keys: str | Sequence[str] | None = None,
  2743. interrupt_before: All | Sequence[str] | None = None,
  2744. interrupt_after: All | Sequence[str] | None = None,
  2745. durability: Durability | None = None,
  2746. **kwargs: Any,
  2747. ) -> dict[str, Any] | Any:
  2748. """Run the graph with a single input and config.
  2749. Args:
  2750. input: The input data for the graph. It can be a dictionary or any other type.
  2751. config: The configuration for the graph run.
  2752. context: The static context to use for the run.
  2753. !!! version-added "Added in version 0.6.0"
  2754. stream_mode: The stream mode for the graph run.
  2755. print_mode: Accepts the same values as `stream_mode`, but only prints the output to the console, for debugging purposes.
  2756. Does not affect the output of the graph in any way.
  2757. output_keys: The output keys to retrieve from the graph run.
  2758. interrupt_before: The nodes to interrupt the graph run before.
  2759. interrupt_after: The nodes to interrupt the graph run after.
  2760. durability: The durability mode for the graph execution, defaults to `"async"`.
  2761. Options are:
  2762. - `"sync"`: Changes are persisted synchronously before the next step starts.
  2763. - `"async"`: Changes are persisted asynchronously while the next step executes.
  2764. - `"exit"`: Changes are persisted only when the graph exits.
  2765. **kwargs: Additional keyword arguments to pass to the graph run.
  2766. Returns:
  2767. The output of the graph run. If `stream_mode` is `"values"`, it returns the latest output.
  2768. If `stream_mode` is not `"values"`, it returns a list of output chunks.
  2769. """
  2770. output_keys = output_keys if output_keys is not None else self.output_channels
  2771. latest: dict[str, Any] | Any = None
  2772. chunks: list[dict[str, Any] | Any] = []
  2773. interrupts: list[Interrupt] = []
  2774. for chunk in self.stream(
  2775. input,
  2776. config,
  2777. context=context,
  2778. stream_mode=["updates", "values"]
  2779. if stream_mode == "values"
  2780. else stream_mode,
  2781. print_mode=print_mode,
  2782. output_keys=output_keys,
  2783. interrupt_before=interrupt_before,
  2784. interrupt_after=interrupt_after,
  2785. durability=durability,
  2786. **kwargs,
  2787. ):
  2788. if stream_mode == "values":
  2789. if len(chunk) == 2:
  2790. mode, payload = cast(tuple[StreamMode, Any], chunk)
  2791. else:
  2792. _, mode, payload = cast(
  2793. tuple[tuple[str, ...], StreamMode, Any], chunk
  2794. )
  2795. if (
  2796. mode == "updates"
  2797. and isinstance(payload, dict)
  2798. and (ints := payload.get(INTERRUPT)) is not None
  2799. ):
  2800. interrupts.extend(ints)
  2801. elif mode == "values":
  2802. latest = payload
  2803. else:
  2804. chunks.append(chunk)
  2805. if stream_mode == "values":
  2806. if interrupts:
  2807. return (
  2808. {**latest, INTERRUPT: interrupts}
  2809. if isinstance(latest, dict)
  2810. else {INTERRUPT: interrupts}
  2811. )
  2812. return latest
  2813. else:
  2814. return chunks
  2815. async def ainvoke(
  2816. self,
  2817. input: InputT | Command | None,
  2818. config: RunnableConfig | None = None,
  2819. *,
  2820. context: ContextT | None = None,
  2821. stream_mode: StreamMode = "values",
  2822. print_mode: StreamMode | Sequence[StreamMode] = (),
  2823. output_keys: str | Sequence[str] | None = None,
  2824. interrupt_before: All | Sequence[str] | None = None,
  2825. interrupt_after: All | Sequence[str] | None = None,
  2826. durability: Durability | None = None,
  2827. **kwargs: Any,
  2828. ) -> dict[str, Any] | Any:
  2829. """Asynchronously run the graph with a single input and config.
  2830. Args:
  2831. input: The input data for the graph. It can be a dictionary or any other type.
  2832. config: The configuration for the graph run.
  2833. context: The static context to use for the run.
  2834. !!! version-added "Added in version 0.6.0"
  2835. stream_mode: The stream mode for the graph run.
  2836. print_mode: Accepts the same values as `stream_mode`, but only prints the output to the console, for debugging purposes.
  2837. Does not affect the output of the graph in any way.
  2838. output_keys: The output keys to retrieve from the graph run.
  2839. interrupt_before: The nodes to interrupt the graph run before.
  2840. interrupt_after: The nodes to interrupt the graph run after.
  2841. durability: The durability mode for the graph execution, defaults to `"async"`.
  2842. Options are:
  2843. - `"sync"`: Changes are persisted synchronously before the next step starts.
  2844. - `"async"`: Changes are persisted asynchronously while the next step executes.
  2845. - `"exit"`: Changes are persisted only when the graph exits.
  2846. **kwargs: Additional keyword arguments to pass to the graph run.
  2847. Returns:
  2848. The output of the graph run. If `stream_mode` is `"values"`, it returns the latest output.
  2849. If `stream_mode` is not `"values"`, it returns a list of output chunks.
  2850. """
  2851. output_keys = output_keys if output_keys is not None else self.output_channels
  2852. latest: dict[str, Any] | Any = None
  2853. chunks: list[dict[str, Any] | Any] = []
  2854. interrupts: list[Interrupt] = []
  2855. async for chunk in self.astream(
  2856. input,
  2857. config,
  2858. context=context,
  2859. stream_mode=["updates", "values"]
  2860. if stream_mode == "values"
  2861. else stream_mode,
  2862. print_mode=print_mode,
  2863. output_keys=output_keys,
  2864. interrupt_before=interrupt_before,
  2865. interrupt_after=interrupt_after,
  2866. durability=durability,
  2867. **kwargs,
  2868. ):
  2869. if stream_mode == "values":
  2870. if len(chunk) == 2:
  2871. mode, payload = cast(tuple[StreamMode, Any], chunk)
  2872. else:
  2873. _, mode, payload = cast(
  2874. tuple[tuple[str, ...], StreamMode, Any], chunk
  2875. )
  2876. if (
  2877. mode == "updates"
  2878. and isinstance(payload, dict)
  2879. and (ints := payload.get(INTERRUPT)) is not None
  2880. ):
  2881. interrupts.extend(ints)
  2882. elif mode == "values":
  2883. latest = payload
  2884. else:
  2885. chunks.append(chunk)
  2886. if stream_mode == "values":
  2887. if interrupts:
  2888. return (
  2889. {**latest, INTERRUPT: interrupts}
  2890. if isinstance(latest, dict)
  2891. else {INTERRUPT: interrupts}
  2892. )
  2893. return latest
  2894. else:
  2895. return chunks
  2896. def clear_cache(self, nodes: Sequence[str] | None = None) -> None:
  2897. """Clear the cache for the given nodes."""
  2898. if not self.cache:
  2899. raise ValueError("No cache is set for this graph. Cannot clear cache.")
  2900. nodes = nodes or self.nodes.keys()
  2901. # collect namespaces to clear
  2902. namespaces: list[tuple[str, ...]] = []
  2903. for node in nodes:
  2904. if node in self.nodes:
  2905. namespaces.append(
  2906. (
  2907. CACHE_NS_WRITES,
  2908. (identifier(self.nodes[node]) or "__dynamic__"),
  2909. node,
  2910. ),
  2911. )
  2912. # clear cache
  2913. self.cache.clear(namespaces)
  2914. async def aclear_cache(self, nodes: Sequence[str] | None = None) -> None:
  2915. """Asynchronously clear the cache for the given nodes."""
  2916. if not self.cache:
  2917. raise ValueError("No cache is set for this graph. Cannot clear cache.")
  2918. nodes = nodes or self.nodes.keys()
  2919. # collect namespaces to clear
  2920. namespaces: list[tuple[str, ...]] = []
  2921. for node in nodes:
  2922. if node in self.nodes:
  2923. namespaces.append(
  2924. (
  2925. CACHE_NS_WRITES,
  2926. (identifier(self.nodes[node]) or "__dynamic__"),
  2927. node,
  2928. ),
  2929. )
  2930. # clear cache
  2931. await self.cache.aclear(namespaces)
  2932. def _trigger_to_nodes(nodes: dict[str, PregelNode]) -> Mapping[str, Sequence[str]]:
  2933. """Index from a trigger to nodes that depend on it."""
  2934. trigger_to_nodes: defaultdict[str, list[str]] = defaultdict(list)
  2935. for name, node in nodes.items():
  2936. for trigger in node.triggers:
  2937. trigger_to_nodes[trigger].append(name)
  2938. return dict(trigger_to_nodes)
  2939. def _output(
  2940. stream_mode: StreamMode | Sequence[StreamMode],
  2941. print_mode: StreamMode | Sequence[StreamMode],
  2942. stream_subgraphs: bool,
  2943. getter: Callable[[], tuple[tuple[str, ...], str, Any]],
  2944. empty_exc: type[Exception],
  2945. ) -> Iterator:
  2946. while True:
  2947. try:
  2948. ns, mode, payload = getter()
  2949. except empty_exc:
  2950. break
  2951. if mode in print_mode:
  2952. if stream_subgraphs and ns:
  2953. print(
  2954. " ".join(
  2955. (
  2956. get_bolded_text(f"[{mode}]"),
  2957. get_colored_text(f"[graph={ns}]", color="yellow"),
  2958. repr(payload),
  2959. )
  2960. )
  2961. )
  2962. else:
  2963. print(
  2964. " ".join(
  2965. (
  2966. get_bolded_text(f"[{mode}]"),
  2967. repr(payload),
  2968. )
  2969. )
  2970. )
  2971. if mode in stream_mode:
  2972. if stream_subgraphs and isinstance(stream_mode, list):
  2973. yield (ns, mode, payload)
  2974. elif isinstance(stream_mode, list):
  2975. yield (mode, payload)
  2976. elif stream_subgraphs:
  2977. yield (ns, payload)
  2978. else:
  2979. yield payload
  2980. def _coerce_context(
  2981. context_schema: type[ContextT] | None, context: Any
  2982. ) -> ContextT | None:
  2983. """Coerce context input to the appropriate schema type.
  2984. If context is a dict and context_schema is a dataclass or pydantic model, we coerce.
  2985. Else, we return the context as-is.
  2986. Args:
  2987. context_schema: The schema type to coerce to (BaseModel, dataclass, or TypedDict)
  2988. context: The context value to coerce
  2989. Returns:
  2990. The coerced context value or None if context is None
  2991. """
  2992. if context is None:
  2993. return None
  2994. if context_schema is None:
  2995. return context
  2996. schema_is_class = issubclass(context_schema, BaseModel) or is_dataclass(
  2997. context_schema
  2998. )
  2999. if isinstance(context, dict) and schema_is_class:
  3000. return context_schema(**context) # type: ignore[misc]
  3001. return cast(ContextT, context)