| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761777177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050205120522053205420552056205720582059206020612062206320642065206620672068206920702071207220732074207520762077207820792080208120822083208420852086208720882089209020912092209320942095209620972098209921002101210221032104210521062107210821092110211121122113211421152116211721182119212021212122212321242125212621272128212921302131213221332134213521362137213821392140214121422143214421452146214721482149215021512152215321542155215621572158215921602161216221632164216521662167216821692170217121722173217421752176217721782179218021812182218321842185218621872188218921902191219221932194219521962197219821992200220122022203220422052206220722082209221022112212221322142215221622172218221922202221222222232224222522262227222822292230223122322233223422352236223722382239224022412242224322442245224622472248224922502251225222532254225522562257225822592260226122622263226422652266226722682269227022712272227322742275227622772278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920292129222923292429252926292729282929293029312932293329342935293629372938293929402941294229432944294529462947294829492950295129522953295429552956295729582959296029612962296329642965296629672968296929702971297229732974297529762977297829792980298129822983298429852986298729882989299029912992299329942995299629972998299930003001300230033004300530063007300830093010301130123013301430153016301730183019302030213022302330243025302630273028302930303031303230333034303530363037303830393040304130423043304430453046304730483049305030513052305330543055305630573058305930603061306230633064306530663067306830693070307130723073307430753076307730783079308030813082308330843085308630873088308930903091309230933094309530963097309830993100310131023103310431053106310731083109311031113112311331143115311631173118311931203121312231233124312531263127312831293130313131323133313431353136313731383139314031413142314331443145314631473148314931503151315231533154315531563157315831593160316131623163316431653166316731683169317031713172317331743175317631773178317931803181318231833184318531863187318831893190319131923193319431953196319731983199320032013202320332043205320632073208320932103211321232133214321532163217321832193220322132223223322432253226322732283229323032313232323332343235323632373238323932403241324232433244324532463247324832493250325132523253325432553256325732583259326032613262326332643265326632673268326932703271327232733274327532763277327832793280328132823283328432853286328732883289329032913292329332943295329632973298329933003301330233033304330533063307330833093310331133123313331433153316331733183319 |
- from __future__ import annotations
- import asyncio
- import concurrent
- import concurrent.futures
- import contextlib
- import queue
- import warnings
- import weakref
- from collections import defaultdict, deque
- from collections.abc import (
- AsyncIterator,
- Awaitable,
- Callable,
- Iterator,
- Mapping,
- Sequence,
- )
- from dataclasses import is_dataclass
- from functools import partial
- from inspect import isclass
- from typing import (
- Any,
- Generic,
- cast,
- get_type_hints,
- )
- from uuid import UUID, uuid5
- from langchain_core.globals import get_debug
- from langchain_core.runnables import (
- RunnableSequence,
- )
- from langchain_core.runnables.base import Input, Output
- from langchain_core.runnables.config import (
- RunnableConfig,
- get_async_callback_manager_for_config,
- get_callback_manager_for_config,
- )
- from langchain_core.runnables.graph import Graph
- from langgraph.cache.base import BaseCache
- from langgraph.checkpoint.base import (
- BaseCheckpointSaver,
- Checkpoint,
- CheckpointTuple,
- )
- from langgraph.store.base import BaseStore
- from pydantic import BaseModel, TypeAdapter
- from typing_extensions import Self, Unpack, deprecated, is_typeddict
- from langgraph._internal._config import (
- ensure_config,
- merge_configs,
- patch_checkpoint_map,
- patch_config,
- patch_configurable,
- recast_checkpoint_ns,
- )
- from langgraph._internal._constants import (
- CACHE_NS_WRITES,
- CONF,
- CONFIG_KEY_CACHE,
- CONFIG_KEY_CHECKPOINT_ID,
- CONFIG_KEY_CHECKPOINT_NS,
- CONFIG_KEY_CHECKPOINTER,
- CONFIG_KEY_DURABILITY,
- CONFIG_KEY_NODE_FINISHED,
- CONFIG_KEY_READ,
- CONFIG_KEY_RUNNER_SUBMIT,
- CONFIG_KEY_RUNTIME,
- CONFIG_KEY_SEND,
- CONFIG_KEY_STREAM,
- CONFIG_KEY_TASK_ID,
- CONFIG_KEY_THREAD_ID,
- ERROR,
- INPUT,
- INTERRUPT,
- NS_END,
- NS_SEP,
- NULL_TASK_ID,
- PUSH,
- TASKS,
- )
- from langgraph._internal._pydantic import create_model
- from langgraph._internal._queue import ( # type: ignore[attr-defined]
- AsyncQueue,
- SyncQueue,
- )
- from langgraph._internal._runnable import (
- Runnable,
- RunnableLike,
- RunnableSeq,
- coerce_to_runnable,
- )
- from langgraph._internal._typing import MISSING, DeprecatedKwargs
- from langgraph.channels.base import BaseChannel
- from langgraph.channels.topic import Topic
- from langgraph.config import get_config
- from langgraph.constants import END
- from langgraph.errors import (
- ErrorCode,
- GraphRecursionError,
- InvalidUpdateError,
- create_error_message,
- )
- from langgraph.managed.base import ManagedValueSpec
- from langgraph.pregel._algo import (
- PregelTaskWrites,
- _scratchpad,
- apply_writes,
- local_read,
- prepare_next_tasks,
- )
- from langgraph.pregel._call import identifier
- from langgraph.pregel._checkpoint import (
- channels_from_checkpoint,
- copy_checkpoint,
- create_checkpoint,
- empty_checkpoint,
- )
- from langgraph.pregel._draw import draw_graph
- from langgraph.pregel._io import map_input, read_channels
- from langgraph.pregel._loop import AsyncPregelLoop, SyncPregelLoop
- from langgraph.pregel._messages import StreamMessagesHandler
- from langgraph.pregel._read import DEFAULT_BOUND, PregelNode
- from langgraph.pregel._retry import RetryPolicy
- from langgraph.pregel._runner import PregelRunner
- from langgraph.pregel._utils import get_new_channel_versions
- from langgraph.pregel._validate import validate_graph, validate_keys
- from langgraph.pregel._write import ChannelWrite, ChannelWriteEntry
- from langgraph.pregel.debug import get_bolded_text, get_colored_text, tasks_w_writes
- from langgraph.pregel.protocol import PregelProtocol, StreamChunk, StreamProtocol
- from langgraph.runtime import DEFAULT_RUNTIME, Runtime
- from langgraph.types import (
- All,
- CachePolicy,
- Checkpointer,
- Command,
- Durability,
- Interrupt,
- Send,
- StateSnapshot,
- StateUpdate,
- StreamMode,
- )
- from langgraph.typing import ContextT, InputT, OutputT, StateT
- from langgraph.warnings import LangGraphDeprecatedSinceV10
- try:
- from langchain_core.tracers._streaming import _StreamingCallbackHandler
- except ImportError:
- _StreamingCallbackHandler = None # type: ignore
- __all__ = ("NodeBuilder", "Pregel")
- _WriteValue = Callable[[Input], Output] | Any
- class NodeBuilder:
- __slots__ = (
- "_channels",
- "_triggers",
- "_tags",
- "_metadata",
- "_writes",
- "_bound",
- "_retry_policy",
- "_cache_policy",
- )
- _channels: str | list[str]
- _triggers: list[str]
- _tags: list[str]
- _metadata: dict[str, Any]
- _writes: list[ChannelWriteEntry]
- _bound: Runnable
- _retry_policy: list[RetryPolicy]
- _cache_policy: CachePolicy | None
- def __init__(
- self,
- ) -> None:
- self._channels = []
- self._triggers = []
- self._tags = []
- self._metadata = {}
- self._writes = []
- self._bound = DEFAULT_BOUND
- self._retry_policy = []
- self._cache_policy = None
- def subscribe_only(
- self,
- channel: str,
- ) -> Self:
- """Subscribe to a single channel."""
- if not self._channels:
- self._channels = channel
- else:
- raise ValueError(
- "Cannot subscribe to single channels when other channels are already subscribed to"
- )
- self._triggers.append(channel)
- return self
- def subscribe_to(
- self,
- *channels: str,
- read: bool = True,
- ) -> Self:
- """Add channels to subscribe to.
- Node will be invoked when any of these channels are updated, with a dict of the
- channel values as input.
- Args:
- channels: Channel name(s) to subscribe to
- read: If `True`, the channels will be included in the input to the node.
- Otherwise, they will trigger the node without being sent in input.
- Returns:
- Self for chaining
- """
- if isinstance(self._channels, str):
- raise ValueError(
- "Cannot subscribe to channels when subscribed to a single channel"
- )
- if read:
- if not self._channels:
- self._channels = list(channels)
- else:
- self._channels.extend(channels)
- if isinstance(channels, str):
- self._triggers.append(channels)
- else:
- self._triggers.extend(channels)
- return self
- def read_from(
- self,
- *channels: str,
- ) -> Self:
- """Adds the specified channels to read from, without subscribing to them."""
- assert isinstance(self._channels, list), (
- "Cannot read additional channels when subscribed to single channels"
- )
- self._channels.extend(channels)
- return self
- def do(
- self,
- node: RunnableLike,
- ) -> Self:
- """Adds the specified node."""
- if self._bound is not DEFAULT_BOUND:
- self._bound = RunnableSeq(
- self._bound, coerce_to_runnable(node, name=None, trace=True)
- )
- else:
- self._bound = coerce_to_runnable(node, name=None, trace=True)
- return self
- def write_to(
- self,
- *channels: str | ChannelWriteEntry,
- **kwargs: _WriteValue,
- ) -> Self:
- """Add channel writes.
- Args:
- *channels: Channel names to write to.
- **kwargs: Channel name and value mappings.
- Returns:
- Self for chaining
- """
- self._writes.extend(
- ChannelWriteEntry(c) if isinstance(c, str) else c for c in channels
- )
- self._writes.extend(
- ChannelWriteEntry(k, mapper=v)
- if callable(v)
- else ChannelWriteEntry(k, value=v)
- for k, v in kwargs.items()
- )
- return self
- def meta(self, *tags: str, **metadata: Any) -> Self:
- """Add tags or metadata to the node."""
- self._tags.extend(tags)
- self._metadata.update(metadata)
- return self
- def add_retry_policies(self, *policies: RetryPolicy) -> Self:
- """Adds retry policies to the node."""
- self._retry_policy.extend(policies)
- return self
- def add_cache_policy(self, policy: CachePolicy) -> Self:
- """Adds cache policies to the node."""
- self._cache_policy = policy
- return self
- def build(self) -> PregelNode:
- """Builds the node."""
- return PregelNode(
- channels=self._channels,
- triggers=self._triggers,
- tags=self._tags,
- metadata=self._metadata,
- writers=[ChannelWrite(self._writes)],
- bound=self._bound,
- retry_policy=self._retry_policy,
- cache_policy=self._cache_policy,
- )
- class Pregel(
- PregelProtocol[StateT, ContextT, InputT, OutputT],
- Generic[StateT, ContextT, InputT, OutputT],
- ):
- """Pregel manages the runtime behavior for LangGraph applications.
- ## Overview
- Pregel combines [**actors**](https://en.wikipedia.org/wiki/Actor_model)
- and **channels** into a single application.
- **Actors** read data from channels and write data to channels.
- Pregel organizes the execution of the application into multiple steps,
- following the **Pregel Algorithm**/**Bulk Synchronous Parallel** model.
- Each step consists of three phases:
- - **Plan**: Determine which **actors** to execute in this step. For example,
- in the first step, select the **actors** that subscribe to the special
- **input** channels; in subsequent steps,
- select the **actors** that subscribe to channels updated in the previous step.
- - **Execution**: Execute all selected **actors** in parallel,
- until all complete, or one fails, or a timeout is reached. During this
- phase, channel updates are invisible to actors until the next step.
- - **Update**: Update the channels with the values written by the **actors**
- in this step.
- Repeat until no **actors** are selected for execution, or a maximum number of
- steps is reached.
- ## Actors
- An **actor** is a `PregelNode`.
- It subscribes to channels, reads data from them, and writes data to them.
- It can be thought of as an **actor** in the Pregel algorithm.
- `PregelNodes` implement LangChain's
- Runnable interface.
- ## Channels
- Channels are used to communicate between actors (`PregelNodes`).
- Each channel has a value type, an update type, and an update function – which
- takes a sequence of updates and
- modifies the stored value. Channels can be used to send data from one chain to
- another, or to send data from a chain to itself in a future step. LangGraph
- provides a number of built-in channels:
- ### Basic channels: LastValue and Topic
- - `LastValue`: The default channel, stores the last value sent to the channel,
- useful for input and output values, or for sending data from one step to the next
- - `Topic`: A configurable PubSub Topic, useful for sending multiple values
- between *actors*, or for accumulating output. Can be configured to deduplicate
- values, and/or to accumulate values over the course of multiple steps.
- ### Advanced channels: Context and BinaryOperatorAggregate
- - `Context`: exposes the value of a context manager, managing its lifecycle.
- Useful for accessing external resources that require setup and/or teardown. e.g.
- `client = Context(httpx.Client)`
- - `BinaryOperatorAggregate`: stores a persistent value, updated by applying
- a binary operator to the current value and each update
- sent to the channel, useful for computing aggregates over multiple steps. e.g.
- `total = BinaryOperatorAggregate(int, operator.add)`
- ## Examples
- Most users will interact with Pregel via a
- [StateGraph (Graph API)][langgraph.graph.StateGraph] or via an
- [entrypoint (Functional API)][langgraph.func.entrypoint].
- However, for **advanced** use cases, Pregel can be used directly. If you're
- not sure whether you need to use Pregel directly, then the answer is probably no
- - you should use the Graph API or Functional API instead. These are higher-level
- interfaces that will compile down to Pregel under the hood.
- Here are some examples to give you a sense of how it works:
- Example: Single node application
- ```python
- from langgraph.channels import EphemeralValue
- from langgraph.pregel import Pregel, NodeBuilder
- node1 = (
- NodeBuilder().subscribe_only("a")
- .do(lambda x: x + x)
- .write_to("b")
- )
- app = Pregel(
- nodes={"node1": node1},
- channels={
- "a": EphemeralValue(str),
- "b": EphemeralValue(str),
- },
- input_channels=["a"],
- output_channels=["b"],
- )
- app.invoke({"a": "foo"})
- ```
- ```con
- {'b': 'foofoo'}
- ```
- Example: Using multiple nodes and multiple output channels
- ```python
- from langgraph.channels import LastValue, EphemeralValue
- from langgraph.pregel import Pregel, NodeBuilder
- node1 = (
- NodeBuilder().subscribe_only("a")
- .do(lambda x: x + x)
- .write_to("b")
- )
- node2 = (
- NodeBuilder().subscribe_to("b")
- .do(lambda x: x["b"] + x["b"])
- .write_to("c")
- )
- app = Pregel(
- nodes={"node1": node1, "node2": node2},
- channels={
- "a": EphemeralValue(str),
- "b": LastValue(str),
- "c": EphemeralValue(str),
- },
- input_channels=["a"],
- output_channels=["b", "c"],
- )
- app.invoke({"a": "foo"})
- ```
- ```con
- {'b': 'foofoo', 'c': 'foofoofoofoo'}
- ```
- Example: Using a Topic channel
- ```python
- from langgraph.channels import LastValue, EphemeralValue, Topic
- from langgraph.pregel import Pregel, NodeBuilder
- node1 = (
- NodeBuilder().subscribe_only("a")
- .do(lambda x: x + x)
- .write_to("b", "c")
- )
- node2 = (
- NodeBuilder().subscribe_only("b")
- .do(lambda x: x + x)
- .write_to("c")
- )
- app = Pregel(
- nodes={"node1": node1, "node2": node2},
- channels={
- "a": EphemeralValue(str),
- "b": EphemeralValue(str),
- "c": Topic(str, accumulate=True),
- },
- input_channels=["a"],
- output_channels=["c"],
- )
- app.invoke({"a": "foo"})
- ```
- ```pycon
- {"c": ["foofoo", "foofoofoofoo"]}
- ```
- Example: Using a `BinaryOperatorAggregate` channel
- ```python
- from langgraph.channels import EphemeralValue, BinaryOperatorAggregate
- from langgraph.pregel import Pregel, NodeBuilder
- node1 = (
- NodeBuilder().subscribe_only("a")
- .do(lambda x: x + x)
- .write_to("b", "c")
- )
- node2 = (
- NodeBuilder().subscribe_only("b")
- .do(lambda x: x + x)
- .write_to("c")
- )
- def reducer(current, update):
- if current:
- return current + " | " + update
- else:
- return update
- app = Pregel(
- nodes={"node1": node1, "node2": node2},
- channels={
- "a": EphemeralValue(str),
- "b": EphemeralValue(str),
- "c": BinaryOperatorAggregate(str, operator=reducer),
- },
- input_channels=["a"],
- output_channels=["c"],
- )
- app.invoke({"a": "foo"})
- ```
- ```con
- {'c': 'foofoo | foofoofoofoo'}
- ```
- Example: Introducing a cycle
- This example demonstrates how to introduce a cycle in the graph, by having
- a chain write to a channel it subscribes to.
- Execution will continue until a `None` value is written to the channel.
- ```python
- from langgraph.channels import EphemeralValue
- from langgraph.pregel import Pregel, NodeBuilder, ChannelWriteEntry
- example_node = (
- NodeBuilder()
- .subscribe_only("value")
- .do(lambda x: x + x if len(x) < 10 else None)
- .write_to(ChannelWriteEntry(channel="value", skip_none=True))
- )
- app = Pregel(
- nodes={"example_node": example_node},
- channels={
- "value": EphemeralValue(str),
- },
- input_channels=["value"],
- output_channels=["value"],
- )
- app.invoke({"value": "a"})
- ```
- ```con
- {'value': 'aaaaaaaaaaaaaaaa'}
- ```
- """
- nodes: dict[str, PregelNode]
- channels: dict[str, BaseChannel | ManagedValueSpec]
- stream_mode: StreamMode = "values"
- """Mode to stream output, defaults to 'values'."""
- stream_eager: bool = False
- """Whether to force emitting stream events eagerly, automatically turned on
- for stream_mode "messages" and "custom"."""
- output_channels: str | Sequence[str]
- stream_channels: str | Sequence[str] | None = None
- """Channels to stream, defaults to all channels not in reserved channels"""
- interrupt_after_nodes: All | Sequence[str]
- interrupt_before_nodes: All | Sequence[str]
- input_channels: str | Sequence[str]
- step_timeout: float | None = None
- """Maximum time to wait for a step to complete, in seconds."""
- debug: bool
- """Whether to print debug information during execution."""
- checkpointer: Checkpointer = None
- """`Checkpointer` used to save and load graph state."""
- store: BaseStore | None = None
- """Memory store to use for SharedValues."""
- cache: BaseCache | None = None
- """Cache to use for storing node results."""
- retry_policy: Sequence[RetryPolicy] = ()
- """Retry policies to use when running tasks. Empty set disables retries."""
- cache_policy: CachePolicy | None = None
- """Cache policy to use for all nodes. Can be overridden by individual nodes."""
- context_schema: type[ContextT] | None = None
- """Specifies the schema for the context object that will be passed to the workflow."""
- config: RunnableConfig | None = None
- name: str = "LangGraph"
- trigger_to_nodes: Mapping[str, Sequence[str]]
- def __init__(
- self,
- *,
- nodes: dict[str, PregelNode | NodeBuilder],
- channels: dict[str, BaseChannel | ManagedValueSpec] | None,
- auto_validate: bool = True,
- stream_mode: StreamMode = "values",
- stream_eager: bool = False,
- output_channels: str | Sequence[str],
- stream_channels: str | Sequence[str] | None = None,
- interrupt_after_nodes: All | Sequence[str] = (),
- interrupt_before_nodes: All | Sequence[str] = (),
- input_channels: str | Sequence[str],
- step_timeout: float | None = None,
- debug: bool | None = None,
- checkpointer: BaseCheckpointSaver | None = None,
- store: BaseStore | None = None,
- cache: BaseCache | None = None,
- retry_policy: RetryPolicy | Sequence[RetryPolicy] = (),
- cache_policy: CachePolicy | None = None,
- context_schema: type[ContextT] | None = None,
- config: RunnableConfig | None = None,
- trigger_to_nodes: Mapping[str, Sequence[str]] | None = None,
- name: str = "LangGraph",
- **deprecated_kwargs: Unpack[DeprecatedKwargs],
- ) -> None:
- if (
- config_type := deprecated_kwargs.get("config_type", MISSING)
- ) is not MISSING:
- warnings.warn(
- "`config_type` is deprecated and will be removed. Please use `context_schema` instead.",
- category=LangGraphDeprecatedSinceV10,
- stacklevel=2,
- )
- if context_schema is None:
- context_schema = cast(type[ContextT], config_type)
- self.nodes = {
- k: v.build() if isinstance(v, NodeBuilder) else v for k, v in nodes.items()
- }
- self.channels = channels or {}
- if TASKS in self.channels and not isinstance(self.channels[TASKS], Topic):
- raise ValueError(
- f"Channel '{TASKS}' is reserved and cannot be used in the graph."
- )
- else:
- self.channels[TASKS] = Topic(Send, accumulate=False)
- self.stream_mode = stream_mode
- self.stream_eager = stream_eager
- self.output_channels = output_channels
- self.stream_channels = stream_channels
- self.interrupt_after_nodes = interrupt_after_nodes
- self.interrupt_before_nodes = interrupt_before_nodes
- self.input_channels = input_channels
- self.step_timeout = step_timeout
- self.debug = debug if debug is not None else get_debug()
- self.checkpointer = checkpointer
- self.store = store
- self.cache = cache
- self.retry_policy = (
- (retry_policy,) if isinstance(retry_policy, RetryPolicy) else retry_policy
- )
- self.cache_policy = cache_policy
- self.context_schema = context_schema
- self.config = config
- self.trigger_to_nodes = trigger_to_nodes or {}
- self.name = name
- if auto_validate:
- self.validate()
- def get_graph(
- self, config: RunnableConfig | None = None, *, xray: int | bool = False
- ) -> Graph:
- """Return a drawable representation of the computation graph."""
- # gather subgraphs
- if xray:
- subgraphs = {
- k: v.get_graph(
- config,
- xray=xray if isinstance(xray, bool) or xray <= 0 else xray - 1,
- )
- for k, v in self.get_subgraphs()
- }
- else:
- subgraphs = {}
- return draw_graph(
- merge_configs(self.config, config),
- nodes=self.nodes,
- specs=self.channels,
- input_channels=self.input_channels,
- interrupt_after_nodes=self.interrupt_after_nodes,
- interrupt_before_nodes=self.interrupt_before_nodes,
- trigger_to_nodes=self.trigger_to_nodes,
- checkpointer=self.checkpointer,
- subgraphs=subgraphs,
- )
- async def aget_graph(
- self, config: RunnableConfig | None = None, *, xray: int | bool = False
- ) -> Graph:
- """Return a drawable representation of the computation graph."""
- # gather subgraphs
- if xray:
- subpregels: dict[str, PregelProtocol] = {
- k: v async for k, v in self.aget_subgraphs()
- }
- subgraphs = {
- k: v
- for k, v in zip(
- subpregels,
- await asyncio.gather(
- *(
- p.aget_graph(
- config,
- xray=xray
- if isinstance(xray, bool) or xray <= 0
- else xray - 1,
- )
- for p in subpregels.values()
- )
- ),
- )
- }
- else:
- subgraphs = {}
- return draw_graph(
- merge_configs(self.config, config),
- nodes=self.nodes,
- specs=self.channels,
- input_channels=self.input_channels,
- interrupt_after_nodes=self.interrupt_after_nodes,
- interrupt_before_nodes=self.interrupt_before_nodes,
- trigger_to_nodes=self.trigger_to_nodes,
- checkpointer=self.checkpointer,
- subgraphs=subgraphs,
- )
- def _repr_mimebundle_(self, **kwargs: Any) -> dict[str, Any]:
- """Mime bundle used by Jupyter to display the graph"""
- return {
- "text/plain": repr(self),
- "image/png": self.get_graph().draw_mermaid_png(),
- }
- def copy(self, update: dict[str, Any] | None = None) -> Self:
- attrs = {k: v for k, v in self.__dict__.items() if k != "__orig_class__"}
- attrs.update(update or {})
- return self.__class__(**attrs)
- def with_config(self, config: RunnableConfig | None = None, **kwargs: Any) -> Self:
- """Create a copy of the Pregel object with an updated config."""
- return self.copy(
- {"config": merge_configs(self.config, config, cast(RunnableConfig, kwargs))}
- )
- def validate(self) -> Self:
- validate_graph(
- self.nodes,
- {k: v for k, v in self.channels.items() if isinstance(v, BaseChannel)},
- {k: v for k, v in self.channels.items() if not isinstance(v, BaseChannel)},
- self.input_channels,
- self.output_channels,
- self.stream_channels,
- self.interrupt_after_nodes,
- self.interrupt_before_nodes,
- )
- self.trigger_to_nodes = _trigger_to_nodes(self.nodes)
- return self
- @deprecated(
- "`config_schema` is deprecated. Use `get_context_jsonschema` for the relevant schema instead.",
- category=None,
- )
- def config_schema(self, *, include: Sequence[str] | None = None) -> type[BaseModel]:
- warnings.warn(
- "`config_schema` is deprecated. Use `get_context_jsonschema` for the relevant schema instead.",
- category=LangGraphDeprecatedSinceV10,
- stacklevel=2,
- )
- include = include or []
- fields = {
- **(
- {"configurable": (self.context_schema, None)}
- if self.context_schema
- else {}
- ),
- **{
- field_name: (field_type, None)
- for field_name, field_type in get_type_hints(RunnableConfig).items()
- if field_name in [i for i in include if i != "configurable"]
- },
- }
- return create_model(self.get_name("Config"), field_definitions=fields)
- @deprecated(
- "`get_config_jsonschema` is deprecated. Use `get_context_jsonschema` instead.",
- category=None,
- )
- def get_config_jsonschema(
- self, *, include: Sequence[str] | None = None
- ) -> dict[str, Any]:
- warnings.warn(
- "`get_config_jsonschema` is deprecated. Use `get_context_jsonschema` instead.",
- category=LangGraphDeprecatedSinceV10,
- stacklevel=2,
- )
- with warnings.catch_warnings():
- warnings.filterwarnings("ignore", category=LangGraphDeprecatedSinceV10)
- schema = self.config_schema(include=include)
- return schema.model_json_schema()
- def get_context_jsonschema(self) -> dict[str, Any] | None:
- if (context_schema := self.context_schema) is None:
- return None
- if isclass(context_schema) and issubclass(context_schema, BaseModel):
- return context_schema.model_json_schema()
- elif is_typeddict(context_schema) or is_dataclass(context_schema):
- return TypeAdapter(context_schema).json_schema()
- else:
- raise ValueError(
- f"Invalid context schema type: {context_schema}. Must be a BaseModel, TypedDict or dataclass."
- )
- @property
- def InputType(self) -> Any:
- if isinstance(self.input_channels, str):
- channel = self.channels[self.input_channels]
- if isinstance(channel, BaseChannel):
- return channel.UpdateType
- def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
- config = merge_configs(self.config, config)
- if isinstance(self.input_channels, str):
- return super().get_input_schema(config)
- else:
- return create_model(
- self.get_name("Input"),
- field_definitions={
- k: (c.UpdateType, None)
- for k in self.input_channels or self.channels.keys()
- if (c := self.channels[k]) and isinstance(c, BaseChannel)
- },
- )
- def get_input_jsonschema(
- self, config: RunnableConfig | None = None
- ) -> dict[str, Any]:
- schema = self.get_input_schema(config)
- return schema.model_json_schema()
- @property
- def OutputType(self) -> Any:
- if isinstance(self.output_channels, str):
- channel = self.channels[self.output_channels]
- if isinstance(channel, BaseChannel):
- return channel.ValueType
- def get_output_schema(
- self, config: RunnableConfig | None = None
- ) -> type[BaseModel]:
- config = merge_configs(self.config, config)
- if isinstance(self.output_channels, str):
- return super().get_output_schema(config)
- else:
- return create_model(
- self.get_name("Output"),
- field_definitions={
- k: (c.ValueType, None)
- for k in self.output_channels
- if (c := self.channels[k]) and isinstance(c, BaseChannel)
- },
- )
- def get_output_jsonschema(
- self, config: RunnableConfig | None = None
- ) -> dict[str, Any]:
- schema = self.get_output_schema(config)
- return schema.model_json_schema()
- @property
- def stream_channels_list(self) -> Sequence[str]:
- stream_channels = self.stream_channels_asis
- return (
- [stream_channels] if isinstance(stream_channels, str) else stream_channels
- )
- @property
- def stream_channels_asis(self) -> str | Sequence[str]:
- return self.stream_channels or [
- k for k in self.channels if isinstance(self.channels[k], BaseChannel)
- ]
- def get_subgraphs(
- self, *, namespace: str | None = None, recurse: bool = False
- ) -> Iterator[tuple[str, PregelProtocol]]:
- """Get the subgraphs of the graph.
- Args:
- namespace: The namespace to filter the subgraphs by.
- recurse: Whether to recurse into the subgraphs.
- If `False`, only the immediate subgraphs will be returned.
- Returns:
- An iterator of the `(namespace, subgraph)` pairs.
- """
- for name, node in self.nodes.items():
- # filter by prefix
- if namespace is not None:
- if not namespace.startswith(name):
- continue
- # find the subgraph, if any
- graph = node.subgraphs[0] if node.subgraphs else None
- # if found, yield recursively
- if graph:
- if name == namespace:
- yield name, graph
- return # we found it, stop searching
- if namespace is None:
- yield name, graph
- if recurse and isinstance(graph, Pregel):
- if namespace is not None:
- namespace = namespace[len(name) + 1 :]
- yield from (
- (f"{name}{NS_SEP}{n}", s)
- for n, s in graph.get_subgraphs(
- namespace=namespace, recurse=recurse
- )
- )
- async def aget_subgraphs(
- self, *, namespace: str | None = None, recurse: bool = False
- ) -> AsyncIterator[tuple[str, PregelProtocol]]:
- """Get the subgraphs of the graph.
- Args:
- namespace: The namespace to filter the subgraphs by.
- recurse: Whether to recurse into the subgraphs.
- If `False`, only the immediate subgraphs will be returned.
- Returns:
- An iterator of the `(namespace, subgraph)` pairs.
- """
- for name, node in self.get_subgraphs(namespace=namespace, recurse=recurse):
- yield name, node
- def _migrate_checkpoint(self, checkpoint: Checkpoint) -> None:
- """Migrate a saved checkpoint to new channel layout."""
- if checkpoint["v"] < 4 and checkpoint.get("pending_sends"):
- pending_sends: list[Send] = checkpoint.pop("pending_sends")
- checkpoint["channel_values"][TASKS] = pending_sends
- checkpoint["channel_versions"][TASKS] = max(
- checkpoint["channel_versions"].values()
- )
- def _prepare_state_snapshot(
- self,
- config: RunnableConfig,
- saved: CheckpointTuple | None,
- recurse: BaseCheckpointSaver | None = None,
- apply_pending_writes: bool = False,
- ) -> StateSnapshot:
- if not saved:
- return StateSnapshot(
- values={},
- next=(),
- config=config,
- metadata=None,
- created_at=None,
- parent_config=None,
- tasks=(),
- interrupts=(),
- )
- # migrate checkpoint if needed
- self._migrate_checkpoint(saved.checkpoint)
- step = saved.metadata.get("step", -1) + 1
- stop = step + 2
- channels, managed = channels_from_checkpoint(
- self.channels,
- saved.checkpoint,
- )
- # tasks for this checkpoint
- next_tasks = prepare_next_tasks(
- saved.checkpoint,
- saved.pending_writes or [],
- self.nodes,
- channels,
- managed,
- saved.config,
- step,
- stop,
- for_execution=True,
- store=self.store,
- checkpointer=(
- self.checkpointer
- if isinstance(self.checkpointer, BaseCheckpointSaver)
- else None
- ),
- manager=None,
- )
- # get the subgraphs
- subgraphs = dict(self.get_subgraphs())
- parent_ns = saved.config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")
- task_states: dict[str, RunnableConfig | StateSnapshot] = {}
- for task in next_tasks.values():
- if task.name not in subgraphs:
- continue
- # assemble checkpoint_ns for this task
- task_ns = f"{task.name}{NS_END}{task.id}"
- if parent_ns:
- task_ns = f"{parent_ns}{NS_SEP}{task_ns}"
- if not recurse:
- # set config as signal that subgraph checkpoints exist
- config = {
- CONF: {
- "thread_id": saved.config[CONF]["thread_id"],
- CONFIG_KEY_CHECKPOINT_NS: task_ns,
- }
- }
- task_states[task.id] = config
- else:
- # get the state of the subgraph
- config = {
- CONF: {
- CONFIG_KEY_CHECKPOINTER: recurse,
- "thread_id": saved.config[CONF]["thread_id"],
- CONFIG_KEY_CHECKPOINT_NS: task_ns,
- }
- }
- task_states[task.id] = subgraphs[task.name].get_state(
- config, subgraphs=True
- )
- # apply pending writes
- if null_writes := [
- w[1:] for w in saved.pending_writes or [] if w[0] == NULL_TASK_ID
- ]:
- apply_writes(
- saved.checkpoint,
- channels,
- [PregelTaskWrites((), INPUT, null_writes, [])],
- None,
- self.trigger_to_nodes,
- )
- if apply_pending_writes and saved.pending_writes:
- for tid, k, v in saved.pending_writes:
- if k in (ERROR, INTERRUPT):
- continue
- if tid not in next_tasks:
- continue
- next_tasks[tid].writes.append((k, v))
- if tasks := [t for t in next_tasks.values() if t.writes]:
- apply_writes(
- saved.checkpoint, channels, tasks, None, self.trigger_to_nodes
- )
- tasks_with_writes = tasks_w_writes(
- next_tasks.values(),
- saved.pending_writes,
- task_states,
- self.stream_channels_asis,
- )
- # assemble the state snapshot
- return StateSnapshot(
- read_channels(channels, self.stream_channels_asis),
- tuple(t.name for t in next_tasks.values() if not t.writes),
- patch_checkpoint_map(saved.config, saved.metadata),
- saved.metadata,
- saved.checkpoint["ts"],
- patch_checkpoint_map(saved.parent_config, saved.metadata),
- tasks_with_writes,
- tuple([i for task in tasks_with_writes for i in task.interrupts]),
- )
- async def _aprepare_state_snapshot(
- self,
- config: RunnableConfig,
- saved: CheckpointTuple | None,
- recurse: BaseCheckpointSaver | None = None,
- apply_pending_writes: bool = False,
- ) -> StateSnapshot:
- if not saved:
- return StateSnapshot(
- values={},
- next=(),
- config=config,
- metadata=None,
- created_at=None,
- parent_config=None,
- tasks=(),
- interrupts=(),
- )
- # migrate checkpoint if needed
- self._migrate_checkpoint(saved.checkpoint)
- step = saved.metadata.get("step", -1) + 1
- stop = step + 2
- channels, managed = channels_from_checkpoint(
- self.channels,
- saved.checkpoint,
- )
- # tasks for this checkpoint
- next_tasks = prepare_next_tasks(
- saved.checkpoint,
- saved.pending_writes or [],
- self.nodes,
- channels,
- managed,
- saved.config,
- step,
- stop,
- for_execution=True,
- store=self.store,
- checkpointer=(
- self.checkpointer
- if isinstance(self.checkpointer, BaseCheckpointSaver)
- else None
- ),
- manager=None,
- )
- # get the subgraphs
- subgraphs = {n: g async for n, g in self.aget_subgraphs()}
- parent_ns = saved.config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")
- task_states: dict[str, RunnableConfig | StateSnapshot] = {}
- for task in next_tasks.values():
- if task.name not in subgraphs:
- continue
- # assemble checkpoint_ns for this task
- task_ns = f"{task.name}{NS_END}{task.id}"
- if parent_ns:
- task_ns = f"{parent_ns}{NS_SEP}{task_ns}"
- if not recurse:
- # set config as signal that subgraph checkpoints exist
- config = {
- CONF: {
- "thread_id": saved.config[CONF]["thread_id"],
- CONFIG_KEY_CHECKPOINT_NS: task_ns,
- }
- }
- task_states[task.id] = config
- else:
- # get the state of the subgraph
- config = {
- CONF: {
- CONFIG_KEY_CHECKPOINTER: recurse,
- "thread_id": saved.config[CONF]["thread_id"],
- CONFIG_KEY_CHECKPOINT_NS: task_ns,
- }
- }
- task_states[task.id] = await subgraphs[task.name].aget_state(
- config, subgraphs=True
- )
- # apply pending writes
- if null_writes := [
- w[1:] for w in saved.pending_writes or [] if w[0] == NULL_TASK_ID
- ]:
- apply_writes(
- saved.checkpoint,
- channels,
- [PregelTaskWrites((), INPUT, null_writes, [])],
- None,
- self.trigger_to_nodes,
- )
- if apply_pending_writes and saved.pending_writes:
- for tid, k, v in saved.pending_writes:
- if k in (ERROR, INTERRUPT):
- continue
- if tid not in next_tasks:
- continue
- next_tasks[tid].writes.append((k, v))
- if tasks := [t for t in next_tasks.values() if t.writes]:
- apply_writes(
- saved.checkpoint, channels, tasks, None, self.trigger_to_nodes
- )
- tasks_with_writes = tasks_w_writes(
- next_tasks.values(),
- saved.pending_writes,
- task_states,
- self.stream_channels_asis,
- )
- # assemble the state snapshot
- return StateSnapshot(
- read_channels(channels, self.stream_channels_asis),
- tuple(t.name for t in next_tasks.values() if not t.writes),
- patch_checkpoint_map(saved.config, saved.metadata),
- saved.metadata,
- saved.checkpoint["ts"],
- patch_checkpoint_map(saved.parent_config, saved.metadata),
- tasks_with_writes,
- tuple([i for task in tasks_with_writes for i in task.interrupts]),
- )
- def get_state(
- self, config: RunnableConfig, *, subgraphs: bool = False
- ) -> StateSnapshot:
- """Get the current state of the graph."""
- checkpointer: BaseCheckpointSaver | None = ensure_config(config)[CONF].get(
- CONFIG_KEY_CHECKPOINTER, self.checkpointer
- )
- if not checkpointer:
- raise ValueError("No checkpointer set")
- if (
- checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")
- ) and CONFIG_KEY_CHECKPOINTER not in config[CONF]:
- # remove task_ids from checkpoint_ns
- recast = recast_checkpoint_ns(checkpoint_ns)
- # find the subgraph with the matching name
- for _, pregel in self.get_subgraphs(namespace=recast, recurse=True):
- return pregel.get_state(
- patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}),
- subgraphs=subgraphs,
- )
- else:
- raise ValueError(f"Subgraph {recast} not found")
- config = merge_configs(self.config, config) if self.config else config
- if self.checkpointer is True:
- ns = cast(str, config[CONF][CONFIG_KEY_CHECKPOINT_NS])
- config = merge_configs(
- config, {CONF: {CONFIG_KEY_CHECKPOINT_NS: recast_checkpoint_ns(ns)}}
- )
- thread_id = config[CONF][CONFIG_KEY_THREAD_ID]
- if not isinstance(thread_id, str):
- config[CONF][CONFIG_KEY_THREAD_ID] = str(thread_id)
- saved = checkpointer.get_tuple(config)
- return self._prepare_state_snapshot(
- config,
- saved,
- recurse=checkpointer if subgraphs else None,
- apply_pending_writes=CONFIG_KEY_CHECKPOINT_ID not in config[CONF],
- )
- async def aget_state(
- self, config: RunnableConfig, *, subgraphs: bool = False
- ) -> StateSnapshot:
- """Get the current state of the graph."""
- checkpointer: BaseCheckpointSaver | None = ensure_config(config)[CONF].get(
- CONFIG_KEY_CHECKPOINTER, self.checkpointer
- )
- if not checkpointer:
- raise ValueError("No checkpointer set")
- if (
- checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")
- ) and CONFIG_KEY_CHECKPOINTER not in config[CONF]:
- # remove task_ids from checkpoint_ns
- recast = recast_checkpoint_ns(checkpoint_ns)
- # find the subgraph with the matching name
- async for _, pregel in self.aget_subgraphs(namespace=recast, recurse=True):
- return await pregel.aget_state(
- patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}),
- subgraphs=subgraphs,
- )
- else:
- raise ValueError(f"Subgraph {recast} not found")
- config = merge_configs(self.config, config) if self.config else config
- if self.checkpointer is True:
- ns = cast(str, config[CONF][CONFIG_KEY_CHECKPOINT_NS])
- config = merge_configs(
- config, {CONF: {CONFIG_KEY_CHECKPOINT_NS: recast_checkpoint_ns(ns)}}
- )
- thread_id = config[CONF][CONFIG_KEY_THREAD_ID]
- if not isinstance(thread_id, str):
- config[CONF][CONFIG_KEY_THREAD_ID] = str(thread_id)
- saved = await checkpointer.aget_tuple(config)
- return await self._aprepare_state_snapshot(
- config,
- saved,
- recurse=checkpointer if subgraphs else None,
- apply_pending_writes=CONFIG_KEY_CHECKPOINT_ID not in config[CONF],
- )
- def get_state_history(
- self,
- config: RunnableConfig,
- *,
- filter: dict[str, Any] | None = None,
- before: RunnableConfig | None = None,
- limit: int | None = None,
- ) -> Iterator[StateSnapshot]:
- """Get the history of the state of the graph."""
- config = ensure_config(config)
- checkpointer: BaseCheckpointSaver | None = config[CONF].get(
- CONFIG_KEY_CHECKPOINTER, self.checkpointer
- )
- if not checkpointer:
- raise ValueError("No checkpointer set")
- if (
- checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")
- ) and CONFIG_KEY_CHECKPOINTER not in config[CONF]:
- # remove task_ids from checkpoint_ns
- recast = recast_checkpoint_ns(checkpoint_ns)
- # find the subgraph with the matching name
- for _, pregel in self.get_subgraphs(namespace=recast, recurse=True):
- yield from pregel.get_state_history(
- patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}),
- filter=filter,
- before=before,
- limit=limit,
- )
- return
- else:
- raise ValueError(f"Subgraph {recast} not found")
- config = merge_configs(
- self.config,
- config,
- {
- CONF: {
- CONFIG_KEY_CHECKPOINT_NS: checkpoint_ns,
- CONFIG_KEY_THREAD_ID: str(config[CONF][CONFIG_KEY_THREAD_ID]),
- }
- },
- )
- # eagerly consume list() to avoid holding up the db cursor
- for checkpoint_tuple in list(
- checkpointer.list(config, before=before, limit=limit, filter=filter)
- ):
- yield self._prepare_state_snapshot(
- checkpoint_tuple.config, checkpoint_tuple
- )
- async def aget_state_history(
- self,
- config: RunnableConfig,
- *,
- filter: dict[str, Any] | None = None,
- before: RunnableConfig | None = None,
- limit: int | None = None,
- ) -> AsyncIterator[StateSnapshot]:
- """Asynchronously get the history of the state of the graph."""
- config = ensure_config(config)
- checkpointer: BaseCheckpointSaver | None = ensure_config(config)[CONF].get(
- CONFIG_KEY_CHECKPOINTER, self.checkpointer
- )
- if not checkpointer:
- raise ValueError("No checkpointer set")
- if (
- checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")
- ) and CONFIG_KEY_CHECKPOINTER not in config[CONF]:
- # remove task_ids from checkpoint_ns
- recast = recast_checkpoint_ns(checkpoint_ns)
- # find the subgraph with the matching name
- async for _, pregel in self.aget_subgraphs(namespace=recast, recurse=True):
- async for state in pregel.aget_state_history(
- patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}),
- filter=filter,
- before=before,
- limit=limit,
- ):
- yield state
- return
- else:
- raise ValueError(f"Subgraph {recast} not found")
- config = merge_configs(
- self.config,
- config,
- {
- CONF: {
- CONFIG_KEY_CHECKPOINT_NS: checkpoint_ns,
- CONFIG_KEY_THREAD_ID: str(config[CONF][CONFIG_KEY_THREAD_ID]),
- }
- },
- )
- # eagerly consume list() to avoid holding up the db cursor
- for checkpoint_tuple in [
- c
- async for c in checkpointer.alist(
- config, before=before, limit=limit, filter=filter
- )
- ]:
- yield await self._aprepare_state_snapshot(
- checkpoint_tuple.config, checkpoint_tuple
- )
- def bulk_update_state(
- self,
- config: RunnableConfig,
- supersteps: Sequence[Sequence[StateUpdate]],
- ) -> RunnableConfig:
- """Apply updates to the graph state in bulk. Requires a checkpointer to be set.
- Args:
- config: The config to apply the updates to.
- supersteps: A list of supersteps, each including a list of updates to apply sequentially to a graph state.
- Each update is a tuple of the form `(values, as_node, task_id)` where `task_id` is optional.
- Raises:
- ValueError: If no checkpointer is set or no updates are provided.
- InvalidUpdateError: If an invalid update is provided.
- Returns:
- RunnableConfig: The updated config.
- """
- checkpointer: BaseCheckpointSaver | None = ensure_config(config)[CONF].get(
- CONFIG_KEY_CHECKPOINTER, self.checkpointer
- )
- if not checkpointer:
- raise ValueError("No checkpointer set")
- if len(supersteps) == 0:
- raise ValueError("No supersteps provided")
- if any(len(u) == 0 for u in supersteps):
- raise ValueError("No updates provided")
- # delegate to subgraph
- if (
- checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")
- ) and CONFIG_KEY_CHECKPOINTER not in config[CONF]:
- # remove task_ids from checkpoint_ns
- recast = recast_checkpoint_ns(checkpoint_ns)
- # find the subgraph with the matching name
- for _, pregel in self.get_subgraphs(namespace=recast, recurse=True):
- return pregel.bulk_update_state(
- patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}),
- supersteps,
- )
- else:
- raise ValueError(f"Subgraph {recast} not found")
- def perform_superstep(
- input_config: RunnableConfig, updates: Sequence[StateUpdate]
- ) -> RunnableConfig:
- # get last checkpoint
- config = ensure_config(self.config, input_config)
- saved = checkpointer.get_tuple(config)
- if saved is not None:
- self._migrate_checkpoint(saved.checkpoint)
- checkpoint = (
- copy_checkpoint(saved.checkpoint) if saved else empty_checkpoint()
- )
- checkpoint_previous_versions = (
- saved.checkpoint["channel_versions"].copy() if saved else {}
- )
- step = saved.metadata.get("step", -1) if saved else -1
- # merge configurable fields with previous checkpoint config
- checkpoint_config = patch_configurable(
- config,
- {
- CONFIG_KEY_CHECKPOINT_NS: config[CONF].get(
- CONFIG_KEY_CHECKPOINT_NS, ""
- )
- },
- )
- if saved:
- checkpoint_config = patch_configurable(config, saved.config[CONF])
- channels, managed = channels_from_checkpoint(
- self.channels,
- checkpoint,
- )
- values, as_node = updates[0][:2]
- # no values as END, just clear all tasks
- if values is None and as_node == END:
- if len(updates) > 1:
- raise InvalidUpdateError(
- "Cannot apply multiple updates when clearing state"
- )
- if saved is not None:
- # tasks for this checkpoint
- next_tasks = prepare_next_tasks(
- checkpoint,
- saved.pending_writes or [],
- self.nodes,
- channels,
- managed,
- saved.config,
- step + 1,
- step + 3,
- for_execution=True,
- store=self.store,
- checkpointer=checkpointer,
- manager=None,
- )
- # apply null writes
- if null_writes := [
- w[1:]
- for w in saved.pending_writes or []
- if w[0] == NULL_TASK_ID
- ]:
- apply_writes(
- checkpoint,
- channels,
- [PregelTaskWrites((), INPUT, null_writes, [])],
- checkpointer.get_next_version,
- self.trigger_to_nodes,
- )
- # apply writes from tasks that already ran
- for tid, k, v in saved.pending_writes or []:
- if k in (ERROR, INTERRUPT):
- continue
- if tid not in next_tasks:
- continue
- next_tasks[tid].writes.append((k, v))
- # clear all current tasks
- apply_writes(
- checkpoint,
- channels,
- next_tasks.values(),
- checkpointer.get_next_version,
- self.trigger_to_nodes,
- )
- # save checkpoint
- next_config = checkpointer.put(
- checkpoint_config,
- create_checkpoint(checkpoint, channels, step),
- {
- "source": "update",
- "step": step + 1,
- "parents": saved.metadata.get("parents", {}) if saved else {},
- },
- get_new_channel_versions(
- checkpoint_previous_versions,
- checkpoint["channel_versions"],
- ),
- )
- return patch_checkpoint_map(
- next_config, saved.metadata if saved else None
- )
- # act as an input
- if as_node == INPUT:
- if len(updates) > 1:
- raise InvalidUpdateError(
- "Cannot apply multiple updates when updating as input"
- )
- if input_writes := deque(map_input(self.input_channels, values)):
- apply_writes(
- checkpoint,
- channels,
- [PregelTaskWrites((), INPUT, input_writes, [])],
- checkpointer.get_next_version,
- self.trigger_to_nodes,
- )
- # apply input write to channels
- next_step = (
- step + 1
- if saved and saved.metadata.get("step") is not None
- else -1
- )
- next_config = checkpointer.put(
- checkpoint_config,
- create_checkpoint(checkpoint, channels, next_step),
- {
- "source": "input",
- "step": next_step,
- "parents": saved.metadata.get("parents", {})
- if saved
- else {},
- },
- get_new_channel_versions(
- checkpoint_previous_versions,
- checkpoint["channel_versions"],
- ),
- )
- # store the writes
- checkpointer.put_writes(
- next_config,
- input_writes,
- str(uuid5(UUID(checkpoint["id"]), INPUT)),
- )
- return patch_checkpoint_map(
- next_config, saved.metadata if saved else None
- )
- else:
- raise InvalidUpdateError(
- f"Received no input writes for {self.input_channels}"
- )
- # copy checkpoint
- if as_node == "__copy__":
- if len(updates) > 1:
- raise InvalidUpdateError(
- "Cannot copy checkpoint with multiple updates"
- )
- if saved is None:
- raise InvalidUpdateError("Cannot copy a non-existent checkpoint")
- next_checkpoint = create_checkpoint(checkpoint, None, step)
- # copy checkpoint
- next_config = checkpointer.put(
- saved.parent_config
- or patch_configurable(
- saved.config, {CONFIG_KEY_CHECKPOINT_ID: None}
- ),
- next_checkpoint,
- {
- "source": "fork",
- "step": step + 1,
- "parents": saved.metadata.get("parents", {}),
- },
- {},
- )
- # we want to both clone a checkpoint and update state in one go.
- # reuse the same task ID if possible.
- if isinstance(values, list) and len(values) > 0:
- # figure out the task IDs for the next update checkpoint
- next_tasks = prepare_next_tasks(
- next_checkpoint,
- saved.pending_writes or [],
- self.nodes,
- channels,
- managed,
- next_config,
- step + 2,
- step + 4,
- for_execution=True,
- store=self.store,
- checkpointer=checkpointer,
- manager=None,
- )
- tasks_group_by = defaultdict(list)
- user_group_by: dict[str, list[StateUpdate]] = defaultdict(list)
- for task in next_tasks.values():
- tasks_group_by[task.name].append(task.id)
- for item in values:
- if not isinstance(item, Sequence):
- raise InvalidUpdateError(
- f"Invalid update item: {item} when copying checkpoint"
- )
- values, as_node = item[:2]
- user_group = user_group_by[as_node]
- tasks_group = tasks_group_by[as_node]
- target_idx = len(user_group)
- task_id = (
- tasks_group[target_idx]
- if target_idx < len(tasks_group)
- else None
- )
- user_group_by[as_node].append(
- StateUpdate(values=values, as_node=as_node, task_id=task_id)
- )
- return perform_superstep(
- patch_checkpoint_map(next_config, saved.metadata),
- [item for lst in user_group_by.values() for item in lst],
- )
- return patch_checkpoint_map(next_config, saved.metadata)
- # task ids can be provided in the StateUpdate, but if not,
- # we use the task id generated by prepare_next_tasks
- node_to_task_ids: dict[str, deque[str]] = defaultdict(deque)
- if saved is not None and saved.pending_writes is not None:
- # we call prepare_next_tasks to discover the task IDs that
- # would have been generated, so we can reuse them and
- # properly populate task.result in state history
- next_tasks = prepare_next_tasks(
- checkpoint,
- saved.pending_writes,
- self.nodes,
- channels,
- managed,
- saved.config,
- step + 1,
- step + 3,
- for_execution=True,
- store=self.store,
- checkpointer=checkpointer,
- manager=None,
- )
- # collect task ids to reuse so we can properly attach task results
- for t in next_tasks.values():
- node_to_task_ids[t.name].append(t.id)
- valid_updates: list[tuple[str, dict[str, Any] | None, str | None]] = []
- if len(updates) == 1:
- values, as_node, task_id = updates[0]
- # find last node that updated the state, if not provided
- if as_node is None and len(self.nodes) == 1:
- as_node = tuple(self.nodes)[0]
- elif as_node is None and not any(
- v
- for vv in checkpoint["versions_seen"].values()
- for v in vv.values()
- ):
- if (
- isinstance(self.input_channels, str)
- and self.input_channels in self.nodes
- ):
- as_node = self.input_channels
- elif as_node is None:
- last_seen_by_node = sorted(
- (v, n)
- for n, seen in checkpoint["versions_seen"].items()
- if n in self.nodes
- for v in seen.values()
- )
- # if two nodes updated the state at the same time, it's ambiguous
- if last_seen_by_node:
- if len(last_seen_by_node) == 1:
- as_node = last_seen_by_node[0][1]
- elif last_seen_by_node[-1][0] != last_seen_by_node[-2][0]:
- as_node = last_seen_by_node[-1][1]
- if as_node is None:
- raise InvalidUpdateError("Ambiguous update, specify as_node")
- if as_node not in self.nodes:
- raise InvalidUpdateError(f"Node {as_node} does not exist")
- valid_updates.append((as_node, values, task_id))
- else:
- for values, as_node, task_id in updates:
- if as_node is None:
- raise InvalidUpdateError(
- "as_node is required when applying multiple updates"
- )
- if as_node not in self.nodes:
- raise InvalidUpdateError(f"Node {as_node} does not exist")
- valid_updates.append((as_node, values, task_id))
- run_tasks: list[PregelTaskWrites] = []
- run_task_ids: list[str] = []
- for as_node, values, provided_task_id in valid_updates:
- # create task to run all writers of the chosen node
- writers = self.nodes[as_node].flat_writers
- if not writers:
- raise InvalidUpdateError(f"Node {as_node} has no writers")
- writes: deque[tuple[str, Any]] = deque()
- task = PregelTaskWrites((), as_node, writes, [INTERRUPT])
- # get the task ids that were prepared for this node
- # if a task id was provided in the StateUpdate, we use it
- # otherwise, we use the next available task id
- prepared_task_ids = node_to_task_ids.get(as_node, deque())
- task_id = provided_task_id or (
- prepared_task_ids.popleft()
- if prepared_task_ids
- else str(uuid5(UUID(checkpoint["id"]), INTERRUPT))
- )
- run_tasks.append(task)
- run_task_ids.append(task_id)
- run = RunnableSequence(*writers) if len(writers) > 1 else writers[0]
- # execute task
- run.invoke(
- values,
- patch_config(
- config,
- run_name=self.name + "UpdateState",
- configurable={
- # deque.extend is thread-safe
- CONFIG_KEY_SEND: writes.extend,
- CONFIG_KEY_TASK_ID: task_id,
- CONFIG_KEY_READ: partial(
- local_read,
- _scratchpad(
- None,
- [],
- task_id,
- "",
- None,
- step,
- step + 2,
- ),
- channels,
- managed,
- task,
- ),
- },
- ),
- )
- # save task writes
- for task_id, task in zip(run_task_ids, run_tasks):
- # channel writes are saved to current checkpoint
- channel_writes = [w for w in task.writes if w[0] != PUSH]
- if saved and channel_writes:
- checkpointer.put_writes(checkpoint_config, channel_writes, task_id)
- # apply to checkpoint and save
- apply_writes(
- checkpoint,
- channels,
- run_tasks,
- checkpointer.get_next_version,
- self.trigger_to_nodes,
- )
- checkpoint = create_checkpoint(checkpoint, channels, step + 1)
- next_config = checkpointer.put(
- checkpoint_config,
- checkpoint,
- {
- "source": "update",
- "step": step + 1,
- "parents": saved.metadata.get("parents", {}) if saved else {},
- },
- get_new_channel_versions(
- checkpoint_previous_versions, checkpoint["channel_versions"]
- ),
- )
- for task_id, task in zip(run_task_ids, run_tasks):
- # save push writes
- if push_writes := [w for w in task.writes if w[0] == PUSH]:
- checkpointer.put_writes(next_config, push_writes, task_id)
- return patch_checkpoint_map(next_config, saved.metadata if saved else None)
- current_config = patch_configurable(
- config, {CONFIG_KEY_THREAD_ID: str(config[CONF][CONFIG_KEY_THREAD_ID])}
- )
- for superstep in supersteps:
- current_config = perform_superstep(current_config, superstep)
- return current_config
- async def abulk_update_state(
- self,
- config: RunnableConfig,
- supersteps: Sequence[Sequence[StateUpdate]],
- ) -> RunnableConfig:
- """Asynchronously apply updates to the graph state in bulk. Requires a checkpointer to be set.
- Args:
- config: The config to apply the updates to.
- supersteps: A list of supersteps, each including a list of updates to apply sequentially to a graph state.
- Each update is a tuple of the form `(values, as_node, task_id)` where `task_id` is optional.
- Raises:
- ValueError: If no checkpointer is set or no updates are provided.
- InvalidUpdateError: If an invalid update is provided.
- Returns:
- RunnableConfig: The updated config.
- """
- checkpointer: BaseCheckpointSaver | None = ensure_config(config)[CONF].get(
- CONFIG_KEY_CHECKPOINTER, self.checkpointer
- )
- if not checkpointer:
- raise ValueError("No checkpointer set")
- if len(supersteps) == 0:
- raise ValueError("No supersteps provided")
- if any(len(u) == 0 for u in supersteps):
- raise ValueError("No updates provided")
- # delegate to subgraph
- if (
- checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")
- ) and CONFIG_KEY_CHECKPOINTER not in config[CONF]:
- # remove task_ids from checkpoint_ns
- recast = recast_checkpoint_ns(checkpoint_ns)
- # find the subgraph with the matching name
- async for _, pregel in self.aget_subgraphs(namespace=recast, recurse=True):
- return await pregel.abulk_update_state(
- patch_configurable(config, {CONFIG_KEY_CHECKPOINTER: checkpointer}),
- supersteps,
- )
- else:
- raise ValueError(f"Subgraph {recast} not found")
- async def aperform_superstep(
- input_config: RunnableConfig, updates: Sequence[StateUpdate]
- ) -> RunnableConfig:
- # get last checkpoint
- config = ensure_config(self.config, input_config)
- saved = await checkpointer.aget_tuple(config)
- if saved is not None:
- self._migrate_checkpoint(saved.checkpoint)
- checkpoint = (
- copy_checkpoint(saved.checkpoint) if saved else empty_checkpoint()
- )
- checkpoint_previous_versions = (
- saved.checkpoint["channel_versions"].copy() if saved else {}
- )
- step = saved.metadata.get("step", -1) if saved else -1
- # merge configurable fields with previous checkpoint config
- checkpoint_config = patch_configurable(
- config,
- {
- CONFIG_KEY_CHECKPOINT_NS: config[CONF].get(
- CONFIG_KEY_CHECKPOINT_NS, ""
- )
- },
- )
- if saved:
- checkpoint_config = patch_configurable(config, saved.config[CONF])
- channels, managed = channels_from_checkpoint(
- self.channels,
- checkpoint,
- )
- values, as_node = updates[0][:2]
- # no values, just clear all tasks
- if values is None and as_node == END:
- if len(updates) > 1:
- raise InvalidUpdateError(
- "Cannot apply multiple updates when clearing state"
- )
- if saved is not None:
- # tasks for this checkpoint
- next_tasks = prepare_next_tasks(
- checkpoint,
- saved.pending_writes or [],
- self.nodes,
- channels,
- managed,
- saved.config,
- step + 1,
- step + 3,
- for_execution=True,
- store=self.store,
- checkpointer=checkpointer,
- manager=None,
- )
- # apply null writes
- if null_writes := [
- w[1:]
- for w in saved.pending_writes or []
- if w[0] == NULL_TASK_ID
- ]:
- apply_writes(
- checkpoint,
- channels,
- [PregelTaskWrites((), INPUT, null_writes, [])],
- checkpointer.get_next_version,
- self.trigger_to_nodes,
- )
- # apply writes from tasks that already ran
- for tid, k, v in saved.pending_writes or []:
- if k in (ERROR, INTERRUPT):
- continue
- if tid not in next_tasks:
- continue
- next_tasks[tid].writes.append((k, v))
- # clear all current tasks
- apply_writes(
- checkpoint,
- channels,
- next_tasks.values(),
- checkpointer.get_next_version,
- self.trigger_to_nodes,
- )
- # save checkpoint
- next_config = await checkpointer.aput(
- checkpoint_config,
- create_checkpoint(checkpoint, channels, step),
- {
- "source": "update",
- "step": step + 1,
- "parents": saved.metadata.get("parents", {}) if saved else {},
- },
- get_new_channel_versions(
- checkpoint_previous_versions, checkpoint["channel_versions"]
- ),
- )
- return patch_checkpoint_map(
- next_config, saved.metadata if saved else None
- )
- # act as an input
- if as_node == INPUT:
- if len(updates) > 1:
- raise InvalidUpdateError(
- "Cannot apply multiple updates when updating as input"
- )
- if input_writes := deque(map_input(self.input_channels, values)):
- apply_writes(
- checkpoint,
- channels,
- [PregelTaskWrites((), INPUT, input_writes, [])],
- checkpointer.get_next_version,
- self.trigger_to_nodes,
- )
- # apply input write to channels
- next_step = (
- step + 1
- if saved and saved.metadata.get("step") is not None
- else -1
- )
- next_config = await checkpointer.aput(
- checkpoint_config,
- create_checkpoint(checkpoint, channels, next_step),
- {
- "source": "input",
- "step": next_step,
- "parents": saved.metadata.get("parents", {})
- if saved
- else {},
- },
- get_new_channel_versions(
- checkpoint_previous_versions,
- checkpoint["channel_versions"],
- ),
- )
- # store the writes
- await checkpointer.aput_writes(
- next_config,
- input_writes,
- str(uuid5(UUID(checkpoint["id"]), INPUT)),
- )
- return patch_checkpoint_map(
- next_config, saved.metadata if saved else None
- )
- else:
- raise InvalidUpdateError(
- f"Received no input writes for {self.input_channels}"
- )
- # no values, copy checkpoint
- if as_node == "__copy__":
- if len(updates) > 1:
- raise InvalidUpdateError(
- "Cannot copy checkpoint with multiple updates"
- )
- if saved is None:
- raise InvalidUpdateError("Cannot copy a non-existent checkpoint")
- next_checkpoint = create_checkpoint(checkpoint, None, step)
- # copy checkpoint
- next_config = await checkpointer.aput(
- saved.parent_config
- or patch_configurable(
- saved.config, {CONFIG_KEY_CHECKPOINT_ID: None}
- ),
- next_checkpoint,
- {
- "source": "fork",
- "step": step + 1,
- "parents": saved.metadata.get("parents", {}),
- },
- {},
- )
- # we want to both clone a checkpoint and update state in one go.
- # reuse the same task ID if possible.
- if isinstance(values, list) and len(values) > 0:
- # figure out the task IDs for the next update checkpoint
- next_tasks = prepare_next_tasks(
- next_checkpoint,
- saved.pending_writes or [],
- self.nodes,
- channels,
- managed,
- next_config,
- step + 2,
- step + 4,
- for_execution=True,
- store=self.store,
- checkpointer=checkpointer,
- manager=None,
- )
- tasks_group_by = defaultdict(list)
- user_group_by: dict[str, list[StateUpdate]] = defaultdict(list)
- for task in next_tasks.values():
- tasks_group_by[task.name].append(task.id)
- for item in values:
- if not isinstance(item, Sequence):
- raise InvalidUpdateError(
- f"Invalid update item: {item} when copying checkpoint"
- )
- values, as_node = item[:2]
- user_group = user_group_by[as_node]
- tasks_group = tasks_group_by[as_node]
- target_idx = len(user_group)
- task_id = (
- tasks_group[target_idx]
- if target_idx < len(tasks_group)
- else None
- )
- user_group_by[as_node].append(
- StateUpdate(values=values, as_node=as_node, task_id=task_id)
- )
- return await aperform_superstep(
- patch_checkpoint_map(next_config, saved.metadata),
- [item for lst in user_group_by.values() for item in lst],
- )
- return patch_checkpoint_map(
- next_config, saved.metadata if saved else None
- )
- # task ids can be provided in the StateUpdate, but if not,
- # we use the task id generated by prepare_next_tasks
- node_to_task_ids: dict[str, deque[str]] = defaultdict(deque)
- if saved is not None and saved.pending_writes is not None:
- # we call prepare_next_tasks to discover the task IDs that
- # would have been generated, so we can reuse them and
- # properly populate task.result in state history
- next_tasks = prepare_next_tasks(
- checkpoint,
- saved.pending_writes,
- self.nodes,
- channels,
- managed,
- saved.config,
- step + 1,
- step + 3,
- for_execution=True,
- store=self.store,
- checkpointer=checkpointer,
- manager=None,
- )
- # collect task ids to reuse so we can properly attach task results
- for t in next_tasks.values():
- node_to_task_ids[t.name].append(t.id)
- valid_updates: list[tuple[str, dict[str, Any] | None, str | None]] = []
- if len(updates) == 1:
- values, as_node, task_id = updates[0]
- # find last node that updated the state, if not provided
- if as_node is None and len(self.nodes) == 1:
- as_node = tuple(self.nodes)[0]
- elif as_node is None and not saved:
- if (
- isinstance(self.input_channels, str)
- and self.input_channels in self.nodes
- ):
- as_node = self.input_channels
- elif as_node is None:
- last_seen_by_node = sorted(
- (v, n)
- for n, seen in checkpoint["versions_seen"].items()
- if n in self.nodes
- for v in seen.values()
- )
- # if two nodes updated the state at the same time, it's ambiguous
- if last_seen_by_node:
- if len(last_seen_by_node) == 1:
- as_node = last_seen_by_node[0][1]
- elif last_seen_by_node[-1][0] != last_seen_by_node[-2][0]:
- as_node = last_seen_by_node[-1][1]
- if as_node is None:
- raise InvalidUpdateError("Ambiguous update, specify as_node")
- if as_node not in self.nodes:
- raise InvalidUpdateError(f"Node {as_node} does not exist")
- valid_updates.append((as_node, values, task_id))
- else:
- for values, as_node, task_id in updates:
- if as_node is None:
- raise InvalidUpdateError(
- "as_node is required when applying multiple updates"
- )
- if as_node not in self.nodes:
- raise InvalidUpdateError(f"Node {as_node} does not exist")
- valid_updates.append((as_node, values, task_id))
- run_tasks: list[PregelTaskWrites] = []
- run_task_ids: list[str] = []
- for as_node, values, provided_task_id in valid_updates:
- # create task to run all writers of the chosen node
- writers = self.nodes[as_node].flat_writers
- if not writers:
- raise InvalidUpdateError(f"Node {as_node} has no writers")
- writes: deque[tuple[str, Any]] = deque()
- task = PregelTaskWrites((), as_node, writes, [INTERRUPT])
- # get the task ids that were prepared for this node
- # if a task id was provided in the StateUpdate, we use it
- # otherwise, we use the next available task id
- prepared_task_ids = node_to_task_ids.get(as_node, deque())
- task_id = provided_task_id or (
- prepared_task_ids.popleft()
- if prepared_task_ids
- else str(uuid5(UUID(checkpoint["id"]), INTERRUPT))
- )
- run_tasks.append(task)
- run_task_ids.append(task_id)
- run = RunnableSequence(*writers) if len(writers) > 1 else writers[0]
- # execute task
- await run.ainvoke(
- values,
- patch_config(
- config,
- run_name=self.name + "UpdateState",
- configurable={
- # deque.extend is thread-safe
- CONFIG_KEY_SEND: writes.extend,
- CONFIG_KEY_TASK_ID: task_id,
- CONFIG_KEY_READ: partial(
- local_read,
- _scratchpad(
- None,
- [],
- task_id,
- "",
- None,
- step,
- step + 2,
- ),
- channels,
- managed,
- task,
- ),
- },
- ),
- )
- # save task writes
- for task_id, task in zip(run_task_ids, run_tasks):
- # channel writes are saved to current checkpoint
- channel_writes = [w for w in task.writes if w[0] != PUSH]
- if saved and channel_writes:
- await checkpointer.aput_writes(
- checkpoint_config, channel_writes, task_id
- )
- # apply to checkpoint and save
- apply_writes(
- checkpoint,
- channels,
- run_tasks,
- checkpointer.get_next_version,
- self.trigger_to_nodes,
- )
- checkpoint = create_checkpoint(checkpoint, channels, step + 1)
- # save checkpoint, after applying writes
- next_config = await checkpointer.aput(
- checkpoint_config,
- checkpoint,
- {
- "source": "update",
- "step": step + 1,
- "parents": saved.metadata.get("parents", {}) if saved else {},
- },
- get_new_channel_versions(
- checkpoint_previous_versions, checkpoint["channel_versions"]
- ),
- )
- for task_id, task in zip(run_task_ids, run_tasks):
- # save push writes
- if push_writes := [w for w in task.writes if w[0] == PUSH]:
- await checkpointer.aput_writes(next_config, push_writes, task_id)
- return patch_checkpoint_map(next_config, saved.metadata if saved else None)
- current_config = patch_configurable(
- config, {CONFIG_KEY_THREAD_ID: str(config[CONF][CONFIG_KEY_THREAD_ID])}
- )
- for superstep in supersteps:
- current_config = await aperform_superstep(current_config, superstep)
- return current_config
- def update_state(
- self,
- config: RunnableConfig,
- values: dict[str, Any] | Any | None,
- as_node: str | None = None,
- task_id: str | None = None,
- ) -> RunnableConfig:
- """Update the state of the graph with the given values, as if they came from
- node `as_node`. If `as_node` is not provided, it will be set to the last node
- that updated the state, if not ambiguous.
- """
- return self.bulk_update_state(config, [[StateUpdate(values, as_node, task_id)]])
- async def aupdate_state(
- self,
- config: RunnableConfig,
- values: dict[str, Any] | Any,
- as_node: str | None = None,
- task_id: str | None = None,
- ) -> RunnableConfig:
- """Asynchronously update the state of the graph with the given values, as if they came from
- node `as_node`. If `as_node` is not provided, it will be set to the last node
- that updated the state, if not ambiguous.
- """
- return await self.abulk_update_state(
- config, [[StateUpdate(values, as_node, task_id)]]
- )
- def _defaults(
- self,
- config: RunnableConfig,
- *,
- stream_mode: StreamMode | Sequence[StreamMode],
- print_mode: StreamMode | Sequence[StreamMode],
- output_keys: str | Sequence[str] | None,
- interrupt_before: All | Sequence[str] | None,
- interrupt_after: All | Sequence[str] | None,
- durability: Durability | None = None,
- ) -> tuple[
- set[StreamMode],
- str | Sequence[str],
- All | Sequence[str],
- All | Sequence[str],
- BaseCheckpointSaver | None,
- BaseStore | None,
- BaseCache | None,
- Durability,
- ]:
- if config["recursion_limit"] < 1:
- raise ValueError("recursion_limit must be at least 1")
- if output_keys is None:
- output_keys = self.stream_channels_asis
- else:
- validate_keys(output_keys, self.channels)
- interrupt_before = interrupt_before or self.interrupt_before_nodes
- interrupt_after = interrupt_after or self.interrupt_after_nodes
- if isinstance(stream_mode, str):
- stream_modes = {stream_mode}
- else:
- stream_modes = set(stream_mode)
- if isinstance(print_mode, str):
- stream_modes.add(print_mode)
- else:
- stream_modes.update(print_mode)
- if self.checkpointer is False:
- checkpointer: BaseCheckpointSaver | None = None
- elif CONFIG_KEY_CHECKPOINTER in config.get(CONF, {}):
- checkpointer = config[CONF][CONFIG_KEY_CHECKPOINTER]
- elif self.checkpointer is True:
- raise RuntimeError("checkpointer=True cannot be used for root graphs.")
- else:
- checkpointer = self.checkpointer
- if checkpointer and not config.get(CONF):
- raise ValueError(
- "Checkpointer requires one or more of the following 'configurable' "
- "keys: thread_id, checkpoint_ns, checkpoint_id"
- )
- if CONFIG_KEY_RUNTIME in config.get(CONF, {}):
- store: BaseStore | None = config[CONF][CONFIG_KEY_RUNTIME].store
- else:
- store = self.store
- if CONFIG_KEY_CACHE in config.get(CONF, {}):
- cache: BaseCache | None = config[CONF][CONFIG_KEY_CACHE]
- else:
- cache = self.cache
- if durability is None:
- durability = config.get(CONF, {}).get(CONFIG_KEY_DURABILITY, "async")
- return (
- stream_modes,
- output_keys,
- interrupt_before,
- interrupt_after,
- checkpointer,
- store,
- cache,
- durability,
- )
- def stream(
- self,
- input: InputT | Command | None,
- config: RunnableConfig | None = None,
- *,
- context: ContextT | None = None,
- stream_mode: StreamMode | Sequence[StreamMode] | None = None,
- print_mode: StreamMode | Sequence[StreamMode] = (),
- output_keys: str | Sequence[str] | None = None,
- interrupt_before: All | Sequence[str] | None = None,
- interrupt_after: All | Sequence[str] | None = None,
- durability: Durability | None = None,
- subgraphs: bool = False,
- debug: bool | None = None,
- **kwargs: Unpack[DeprecatedKwargs],
- ) -> Iterator[dict[str, Any] | Any]:
- """Stream graph steps for a single input.
- Args:
- input: The input to the graph.
- config: The configuration to use for the run.
- context: The static context to use for the run.
- !!! version-added "Added in version 0.6.0"
- stream_mode: The mode to stream output, defaults to `self.stream_mode`.
- Options are:
- - `"values"`: Emit all values in the state after each step, including interrupts.
- When used with functional API, values are emitted once at the end of the workflow.
- - `"updates"`: Emit only the node or task names and updates returned by the nodes or tasks after each step.
- If multiple updates are made in the same step (e.g. multiple nodes are run) then those updates are emitted separately.
- - `"custom"`: Emit custom data from inside nodes or tasks using `StreamWriter`.
- - `"messages"`: Emit LLM messages token-by-token together with metadata for any LLM invocations inside nodes or tasks.
- - Will be emitted as 2-tuples `(LLM token, metadata)`.
- - `"checkpoints"`: Emit an event when a checkpoint is created, in the same format as returned by `get_state()`.
- - `"tasks"`: Emit events when tasks start and finish, including their results and errors.
- - `"debug"`: Emit debug events with as much information as possible for each step.
- You can pass a list as the `stream_mode` parameter to stream multiple modes at once.
- The streamed outputs will be tuples of `(mode, data)`.
- See [LangGraph streaming guide](https://docs.langchain.com/oss/python/langgraph/streaming) for more details.
- print_mode: Accepts the same values as `stream_mode`, but only prints the output to the console, for debugging purposes.
- Does not affect the output of the graph in any way.
- output_keys: The keys to stream, defaults to all non-context channels.
- interrupt_before: Nodes to interrupt before, defaults to all nodes in the graph.
- interrupt_after: Nodes to interrupt after, defaults to all nodes in the graph.
- durability: The durability mode for the graph execution, defaults to `"async"`.
- Options are:
- - `"sync"`: Changes are persisted synchronously before the next step starts.
- - `"async"`: Changes are persisted asynchronously while the next step executes.
- - `"exit"`: Changes are persisted only when the graph exits.
- subgraphs: Whether to stream events from inside subgraphs, defaults to `False`.
- If `True`, the events will be emitted as tuples `(namespace, data)`,
- or `(namespace, mode, data)` if `stream_mode` is a list,
- where `namespace` is a tuple with the path to the node where a subgraph is invoked,
- e.g. `("parent_node:<task_id>", "child_node:<task_id>")`.
- See [LangGraph streaming guide](https://docs.langchain.com/oss/python/langgraph/streaming) for more details.
- Yields:
- The output of each step in the graph. The output shape depends on the `stream_mode`.
- """
- if (checkpoint_during := kwargs.get("checkpoint_during")) is not None:
- warnings.warn(
- "`checkpoint_during` is deprecated and will be removed. Please use `durability` instead.",
- category=LangGraphDeprecatedSinceV10,
- stacklevel=2,
- )
- if durability is not None:
- raise ValueError(
- "Cannot use both `checkpoint_during` and `durability` parameters. Please use `durability` instead."
- )
- durability = "async" if checkpoint_during else "exit"
- if stream_mode is None:
- # if being called as a node in another graph, default to values mode
- # but don't overwrite stream_mode arg if provided
- stream_mode = (
- "values"
- if config is not None and CONFIG_KEY_TASK_ID in config.get(CONF, {})
- else self.stream_mode
- )
- if debug or self.debug:
- print_mode = ["updates", "values"]
- stream = SyncQueue()
- config = ensure_config(self.config, config)
- callback_manager = get_callback_manager_for_config(config)
- run_manager = callback_manager.on_chain_start(
- None,
- input,
- name=config.get("run_name", self.get_name()),
- run_id=config.get("run_id"),
- )
- try:
- # assign defaults
- (
- stream_modes,
- output_keys,
- interrupt_before_,
- interrupt_after_,
- checkpointer,
- store,
- cache,
- durability_,
- ) = self._defaults(
- config,
- stream_mode=stream_mode,
- print_mode=print_mode,
- output_keys=output_keys,
- interrupt_before=interrupt_before,
- interrupt_after=interrupt_after,
- durability=durability,
- )
- if checkpointer is None and durability is not None:
- warnings.warn(
- "`durability` has no effect when no checkpointer is present.",
- )
- # set up subgraph checkpointing
- if self.checkpointer is True:
- ns = cast(str, config[CONF][CONFIG_KEY_CHECKPOINT_NS])
- config[CONF][CONFIG_KEY_CHECKPOINT_NS] = recast_checkpoint_ns(ns)
- # set up messages stream mode
- if "messages" in stream_modes:
- ns_ = cast(str | None, config[CONF].get(CONFIG_KEY_CHECKPOINT_NS))
- run_manager.inheritable_handlers.append(
- StreamMessagesHandler(
- stream.put,
- subgraphs,
- parent_ns=tuple(ns_.split(NS_SEP)) if ns_ else None,
- )
- )
- # set up custom stream mode
- if "custom" in stream_modes:
- def stream_writer(c: Any) -> None:
- stream.put(
- (
- tuple(
- get_config()[CONF][CONFIG_KEY_CHECKPOINT_NS].split(
- NS_SEP
- )[:-1]
- ),
- "custom",
- c,
- )
- )
- elif CONFIG_KEY_STREAM in config[CONF]:
- stream_writer = config[CONF][CONFIG_KEY_RUNTIME].stream_writer
- else:
- def stream_writer(c: Any) -> None:
- pass
- # set durability mode for subgraphs
- if durability is not None:
- config[CONF][CONFIG_KEY_DURABILITY] = durability_
- runtime = Runtime(
- context=_coerce_context(self.context_schema, context),
- store=store,
- stream_writer=stream_writer,
- previous=None,
- )
- parent_runtime = config[CONF].get(CONFIG_KEY_RUNTIME, DEFAULT_RUNTIME)
- runtime = parent_runtime.merge(runtime)
- config[CONF][CONFIG_KEY_RUNTIME] = runtime
- with SyncPregelLoop(
- input,
- stream=StreamProtocol(stream.put, stream_modes),
- config=config,
- store=store,
- cache=cache,
- checkpointer=checkpointer,
- nodes=self.nodes,
- specs=self.channels,
- output_keys=output_keys,
- input_keys=self.input_channels,
- stream_keys=self.stream_channels_asis,
- interrupt_before=interrupt_before_,
- interrupt_after=interrupt_after_,
- manager=run_manager,
- durability=durability_,
- trigger_to_nodes=self.trigger_to_nodes,
- migrate_checkpoint=self._migrate_checkpoint,
- retry_policy=self.retry_policy,
- cache_policy=self.cache_policy,
- ) as loop:
- # create runner
- runner = PregelRunner(
- submit=config[CONF].get(
- CONFIG_KEY_RUNNER_SUBMIT, weakref.WeakMethod(loop.submit)
- ),
- put_writes=weakref.WeakMethod(loop.put_writes),
- node_finished=config[CONF].get(CONFIG_KEY_NODE_FINISHED),
- )
- # enable subgraph streaming
- if subgraphs:
- loop.config[CONF][CONFIG_KEY_STREAM] = loop.stream
- # enable concurrent streaming
- get_waiter: Callable[[], concurrent.futures.Future[None]] | None = None
- if (
- self.stream_eager
- or subgraphs
- or "messages" in stream_modes
- or "custom" in stream_modes
- ):
- # we are careful to have a single waiter live at any one time
- # because on exit we increment semaphore count by exactly 1
- waiter: concurrent.futures.Future | None = None
- # because sync futures cannot be cancelled, we instead
- # release the stream semaphore on exit, which will cause
- # a pending waiter to return immediately
- loop.stack.callback(stream._count.release)
- def get_waiter() -> concurrent.futures.Future[None]:
- nonlocal waiter
- if waiter is None or waiter.done():
- waiter = loop.submit(stream.wait)
- return waiter
- else:
- return waiter
- # Similarly to Bulk Synchronous Parallel / Pregel model
- # computation proceeds in steps, while there are channel updates.
- # Channel updates from step N are only visible in step N+1
- # channels are guaranteed to be immutable for the duration of the step,
- # with channel updates applied only at the transition between steps.
- while loop.tick():
- for task in loop.match_cached_writes():
- loop.output_writes(task.id, task.writes, cached=True)
- for _ in runner.tick(
- [t for t in loop.tasks.values() if not t.writes],
- timeout=self.step_timeout,
- get_waiter=get_waiter,
- schedule_task=loop.accept_push,
- ):
- # emit output
- yield from _output(
- stream_mode, print_mode, subgraphs, stream.get, queue.Empty
- )
- loop.after_tick()
- # wait for checkpoint
- if durability_ == "sync":
- loop._put_checkpoint_fut.result()
- # emit output
- yield from _output(
- stream_mode, print_mode, subgraphs, stream.get, queue.Empty
- )
- # handle exit
- if loop.status == "out_of_steps":
- msg = create_error_message(
- message=(
- f"Recursion limit of {config['recursion_limit']} reached "
- "without hitting a stop condition. You can increase the "
- "limit by setting the `recursion_limit` config key."
- ),
- error_code=ErrorCode.GRAPH_RECURSION_LIMIT,
- )
- raise GraphRecursionError(msg)
- # set final channel values as run output
- run_manager.on_chain_end(loop.output)
- except BaseException as e:
- run_manager.on_chain_error(e)
- raise
- async def astream(
- self,
- input: InputT | Command | None,
- config: RunnableConfig | None = None,
- *,
- context: ContextT | None = None,
- stream_mode: StreamMode | Sequence[StreamMode] | None = None,
- print_mode: StreamMode | Sequence[StreamMode] = (),
- output_keys: str | Sequence[str] | None = None,
- interrupt_before: All | Sequence[str] | None = None,
- interrupt_after: All | Sequence[str] | None = None,
- durability: Durability | None = None,
- subgraphs: bool = False,
- debug: bool | None = None,
- **kwargs: Unpack[DeprecatedKwargs],
- ) -> AsyncIterator[dict[str, Any] | Any]:
- """Asynchronously stream graph steps for a single input.
- Args:
- input: The input to the graph.
- config: The configuration to use for the run.
- context: The static context to use for the run.
- !!! version-added "Added in version 0.6.0"
- stream_mode: The mode to stream output, defaults to `self.stream_mode`.
- Options are:
- - `"values"`: Emit all values in the state after each step, including interrupts.
- When used with functional API, values are emitted once at the end of the workflow.
- - `"updates"`: Emit only the node or task names and updates returned by the nodes or tasks after each step.
- If multiple updates are made in the same step (e.g. multiple nodes are run) then those updates are emitted separately.
- - `"custom"`: Emit custom data from inside nodes or tasks using `StreamWriter`.
- - `"messages"`: Emit LLM messages token-by-token together with metadata for any LLM invocations inside nodes or tasks.
- - Will be emitted as 2-tuples `(LLM token, metadata)`.
- - `"checkpoints"`: Emit an event when a checkpoint is created, in the same format as returned by `get_state()`.
- - `"tasks"`: Emit events when tasks start and finish, including their results and errors.
- - `"debug"`: Emit debug events with as much information as possible for each step.
- You can pass a list as the `stream_mode` parameter to stream multiple modes at once.
- The streamed outputs will be tuples of `(mode, data)`.
- See [LangGraph streaming guide](https://docs.langchain.com/oss/python/langgraph/streaming) for more details.
- print_mode: Accepts the same values as `stream_mode`, but only prints the output to the console, for debugging purposes.
- Does not affect the output of the graph in any way.
- output_keys: The keys to stream, defaults to all non-context channels.
- interrupt_before: Nodes to interrupt before, defaults to all nodes in the graph.
- interrupt_after: Nodes to interrupt after, defaults to all nodes in the graph.
- durability: The durability mode for the graph execution, defaults to `"async"`.
- Options are:
- - `"sync"`: Changes are persisted synchronously before the next step starts.
- - `"async"`: Changes are persisted asynchronously while the next step executes.
- - `"exit"`: Changes are persisted only when the graph exits.
- subgraphs: Whether to stream events from inside subgraphs, defaults to `False`.
- If `True`, the events will be emitted as tuples `(namespace, data)`,
- or `(namespace, mode, data)` if `stream_mode` is a list,
- where `namespace` is a tuple with the path to the node where a subgraph is invoked,
- e.g. `("parent_node:<task_id>", "child_node:<task_id>")`.
- See [LangGraph streaming guide](https://docs.langchain.com/oss/python/langgraph/streaming) for more details.
- Yields:
- The output of each step in the graph. The output shape depends on the `stream_mode`.
- """
- if (checkpoint_during := kwargs.get("checkpoint_during")) is not None:
- warnings.warn(
- "`checkpoint_during` is deprecated and will be removed. Please use `durability` instead.",
- category=LangGraphDeprecatedSinceV10,
- stacklevel=2,
- )
- if durability is not None:
- raise ValueError(
- "Cannot use both `checkpoint_during` and `durability` parameters. Please use `durability` instead."
- )
- durability = "async" if checkpoint_during else "exit"
- if stream_mode is None:
- # if being called as a node in another graph, default to values mode
- # but don't overwrite stream_mode arg if provided
- stream_mode = (
- "values"
- if config is not None and CONFIG_KEY_TASK_ID in config.get(CONF, {})
- else self.stream_mode
- )
- if debug or self.debug:
- print_mode = ["updates", "values"]
- stream = AsyncQueue()
- aioloop = asyncio.get_running_loop()
- stream_put = cast(
- Callable[[StreamChunk], None],
- partial(aioloop.call_soon_threadsafe, stream.put_nowait),
- )
- config = ensure_config(self.config, config)
- callback_manager = get_async_callback_manager_for_config(config)
- run_manager = await callback_manager.on_chain_start(
- None,
- input,
- name=config.get("run_name", self.get_name()),
- run_id=config.get("run_id"),
- )
- # if running from astream_log() run each proc with streaming
- do_stream = (
- next(
- (
- True
- for h in run_manager.handlers
- if isinstance(h, _StreamingCallbackHandler)
- and not isinstance(h, StreamMessagesHandler)
- ),
- False,
- )
- if _StreamingCallbackHandler is not None
- else False
- )
- try:
- # assign defaults
- (
- stream_modes,
- output_keys,
- interrupt_before_,
- interrupt_after_,
- checkpointer,
- store,
- cache,
- durability_,
- ) = self._defaults(
- config,
- stream_mode=stream_mode,
- print_mode=print_mode,
- output_keys=output_keys,
- interrupt_before=interrupt_before,
- interrupt_after=interrupt_after,
- durability=durability,
- )
- if checkpointer is None and durability is not None:
- warnings.warn(
- "`durability` has no effect when no checkpointer is present.",
- )
- # set up subgraph checkpointing
- if self.checkpointer is True:
- ns = cast(str, config[CONF][CONFIG_KEY_CHECKPOINT_NS])
- config[CONF][CONFIG_KEY_CHECKPOINT_NS] = recast_checkpoint_ns(ns)
- # set up messages stream mode
- if "messages" in stream_modes:
- # namespace can be None in a root level graph?
- ns_ = cast(str | None, config[CONF].get(CONFIG_KEY_CHECKPOINT_NS))
- run_manager.inheritable_handlers.append(
- StreamMessagesHandler(
- stream_put,
- subgraphs,
- parent_ns=tuple(ns_.split(NS_SEP)) if ns_ else None,
- )
- )
- # set up custom stream mode
- def stream_writer(c: Any) -> None:
- aioloop.call_soon_threadsafe(
- stream.put_nowait,
- (
- tuple(
- get_config()[CONF][CONFIG_KEY_CHECKPOINT_NS].split(NS_SEP)[
- :-1
- ]
- ),
- "custom",
- c,
- ),
- )
- if "custom" in stream_modes:
- def stream_writer(c: Any) -> None:
- aioloop.call_soon_threadsafe(
- stream.put_nowait,
- (
- tuple(
- get_config()[CONF][CONFIG_KEY_CHECKPOINT_NS].split(
- NS_SEP
- )[:-1]
- ),
- "custom",
- c,
- ),
- )
- elif CONFIG_KEY_STREAM in config[CONF]:
- stream_writer = config[CONF][CONFIG_KEY_RUNTIME].stream_writer
- else:
- def stream_writer(c: Any) -> None:
- pass
- # set durability mode for subgraphs
- if durability is not None:
- config[CONF][CONFIG_KEY_DURABILITY] = durability_
- runtime = Runtime(
- context=_coerce_context(self.context_schema, context),
- store=store,
- stream_writer=stream_writer,
- previous=None,
- )
- parent_runtime = config[CONF].get(CONFIG_KEY_RUNTIME, DEFAULT_RUNTIME)
- runtime = parent_runtime.merge(runtime)
- config[CONF][CONFIG_KEY_RUNTIME] = runtime
- async with AsyncPregelLoop(
- input,
- stream=StreamProtocol(stream.put_nowait, stream_modes),
- config=config,
- store=store,
- cache=cache,
- checkpointer=checkpointer,
- nodes=self.nodes,
- specs=self.channels,
- output_keys=output_keys,
- input_keys=self.input_channels,
- stream_keys=self.stream_channels_asis,
- interrupt_before=interrupt_before_,
- interrupt_after=interrupt_after_,
- manager=run_manager,
- durability=durability_,
- trigger_to_nodes=self.trigger_to_nodes,
- migrate_checkpoint=self._migrate_checkpoint,
- retry_policy=self.retry_policy,
- cache_policy=self.cache_policy,
- ) as loop:
- # create runner
- runner = PregelRunner(
- submit=config[CONF].get(
- CONFIG_KEY_RUNNER_SUBMIT, weakref.WeakMethod(loop.submit)
- ),
- put_writes=weakref.WeakMethod(loop.put_writes),
- use_astream=do_stream,
- node_finished=config[CONF].get(CONFIG_KEY_NODE_FINISHED),
- )
- # enable subgraph streaming
- if subgraphs:
- loop.config[CONF][CONFIG_KEY_STREAM] = StreamProtocol(
- stream_put, stream_modes
- )
- # enable concurrent streaming
- get_waiter: Callable[[], asyncio.Task[None]] | None = None
- _cleanup_waiter: Callable[[], Awaitable[None]] | None = None
- if (
- self.stream_eager
- or subgraphs
- or "messages" in stream_modes
- or "custom" in stream_modes
- ):
- # Keep a single waiter task alive; ensure cleanup on exit.
- waiter: asyncio.Task[None] | None = None
- def get_waiter() -> asyncio.Task[None]:
- nonlocal waiter
- if waiter is None or waiter.done():
- waiter = aioloop.create_task(stream.wait())
- def _clear(t: asyncio.Task[None]) -> None:
- nonlocal waiter
- if waiter is t:
- waiter = None
- waiter.add_done_callback(_clear)
- return waiter
- async def _cleanup_waiter() -> None:
- """Wake pending waiter and/or cancel+await to avoid pending tasks."""
- nonlocal waiter
- # Try to wake via semaphore like SyncPregelLoop
- with contextlib.suppress(Exception):
- if hasattr(stream, "_count"):
- stream._count.release()
- t = waiter
- waiter = None
- if t is not None and not t.done():
- t.cancel()
- with contextlib.suppress(asyncio.CancelledError):
- await t
- # Similarly to Bulk Synchronous Parallel / Pregel model
- # computation proceeds in steps, while there are channel updates
- # channel updates from step N are only visible in step N+1
- # channels are guaranteed to be immutable for the duration of the step,
- # with channel updates applied only at the transition between steps
- try:
- while loop.tick():
- for task in await loop.amatch_cached_writes():
- loop.output_writes(task.id, task.writes, cached=True)
- async for _ in runner.atick(
- [t for t in loop.tasks.values() if not t.writes],
- timeout=self.step_timeout,
- get_waiter=get_waiter,
- schedule_task=loop.aaccept_push,
- ):
- # emit output
- for o in _output(
- stream_mode,
- print_mode,
- subgraphs,
- stream.get_nowait,
- asyncio.QueueEmpty,
- ):
- yield o
- loop.after_tick()
- # wait for checkpoint
- if durability_ == "sync":
- await cast(asyncio.Future, loop._put_checkpoint_fut)
- finally:
- # ensure waiter doesn't remain pending on cancel/shutdown
- if _cleanup_waiter is not None:
- await _cleanup_waiter()
- # emit output
- for o in _output(
- stream_mode,
- print_mode,
- subgraphs,
- stream.get_nowait,
- asyncio.QueueEmpty,
- ):
- yield o
- # handle exit
- if loop.status == "out_of_steps":
- msg = create_error_message(
- message=(
- f"Recursion limit of {config['recursion_limit']} reached "
- "without hitting a stop condition. You can increase the "
- "limit by setting the `recursion_limit` config key."
- ),
- error_code=ErrorCode.GRAPH_RECURSION_LIMIT,
- )
- raise GraphRecursionError(msg)
- # set final channel values as run output
- await run_manager.on_chain_end(loop.output)
- except BaseException as e:
- await asyncio.shield(run_manager.on_chain_error(e))
- raise
- def invoke(
- self,
- input: InputT | Command | None,
- config: RunnableConfig | None = None,
- *,
- context: ContextT | None = None,
- stream_mode: StreamMode = "values",
- print_mode: StreamMode | Sequence[StreamMode] = (),
- output_keys: str | Sequence[str] | None = None,
- interrupt_before: All | Sequence[str] | None = None,
- interrupt_after: All | Sequence[str] | None = None,
- durability: Durability | None = None,
- **kwargs: Any,
- ) -> dict[str, Any] | Any:
- """Run the graph with a single input and config.
- Args:
- input: The input data for the graph. It can be a dictionary or any other type.
- config: The configuration for the graph run.
- context: The static context to use for the run.
- !!! version-added "Added in version 0.6.0"
- stream_mode: The stream mode for the graph run.
- print_mode: Accepts the same values as `stream_mode`, but only prints the output to the console, for debugging purposes.
- Does not affect the output of the graph in any way.
- output_keys: The output keys to retrieve from the graph run.
- interrupt_before: The nodes to interrupt the graph run before.
- interrupt_after: The nodes to interrupt the graph run after.
- durability: The durability mode for the graph execution, defaults to `"async"`.
- Options are:
- - `"sync"`: Changes are persisted synchronously before the next step starts.
- - `"async"`: Changes are persisted asynchronously while the next step executes.
- - `"exit"`: Changes are persisted only when the graph exits.
- **kwargs: Additional keyword arguments to pass to the graph run.
- Returns:
- The output of the graph run. If `stream_mode` is `"values"`, it returns the latest output.
- If `stream_mode` is not `"values"`, it returns a list of output chunks.
- """
- output_keys = output_keys if output_keys is not None else self.output_channels
- latest: dict[str, Any] | Any = None
- chunks: list[dict[str, Any] | Any] = []
- interrupts: list[Interrupt] = []
- for chunk in self.stream(
- input,
- config,
- context=context,
- stream_mode=["updates", "values"]
- if stream_mode == "values"
- else stream_mode,
- print_mode=print_mode,
- output_keys=output_keys,
- interrupt_before=interrupt_before,
- interrupt_after=interrupt_after,
- durability=durability,
- **kwargs,
- ):
- if stream_mode == "values":
- if len(chunk) == 2:
- mode, payload = cast(tuple[StreamMode, Any], chunk)
- else:
- _, mode, payload = cast(
- tuple[tuple[str, ...], StreamMode, Any], chunk
- )
- if (
- mode == "updates"
- and isinstance(payload, dict)
- and (ints := payload.get(INTERRUPT)) is not None
- ):
- interrupts.extend(ints)
- elif mode == "values":
- latest = payload
- else:
- chunks.append(chunk)
- if stream_mode == "values":
- if interrupts:
- return (
- {**latest, INTERRUPT: interrupts}
- if isinstance(latest, dict)
- else {INTERRUPT: interrupts}
- )
- return latest
- else:
- return chunks
- async def ainvoke(
- self,
- input: InputT | Command | None,
- config: RunnableConfig | None = None,
- *,
- context: ContextT | None = None,
- stream_mode: StreamMode = "values",
- print_mode: StreamMode | Sequence[StreamMode] = (),
- output_keys: str | Sequence[str] | None = None,
- interrupt_before: All | Sequence[str] | None = None,
- interrupt_after: All | Sequence[str] | None = None,
- durability: Durability | None = None,
- **kwargs: Any,
- ) -> dict[str, Any] | Any:
- """Asynchronously run the graph with a single input and config.
- Args:
- input: The input data for the graph. It can be a dictionary or any other type.
- config: The configuration for the graph run.
- context: The static context to use for the run.
- !!! version-added "Added in version 0.6.0"
- stream_mode: The stream mode for the graph run.
- print_mode: Accepts the same values as `stream_mode`, but only prints the output to the console, for debugging purposes.
- Does not affect the output of the graph in any way.
- output_keys: The output keys to retrieve from the graph run.
- interrupt_before: The nodes to interrupt the graph run before.
- interrupt_after: The nodes to interrupt the graph run after.
- durability: The durability mode for the graph execution, defaults to `"async"`.
- Options are:
- - `"sync"`: Changes are persisted synchronously before the next step starts.
- - `"async"`: Changes are persisted asynchronously while the next step executes.
- - `"exit"`: Changes are persisted only when the graph exits.
- **kwargs: Additional keyword arguments to pass to the graph run.
- Returns:
- The output of the graph run. If `stream_mode` is `"values"`, it returns the latest output.
- If `stream_mode` is not `"values"`, it returns a list of output chunks.
- """
- output_keys = output_keys if output_keys is not None else self.output_channels
- latest: dict[str, Any] | Any = None
- chunks: list[dict[str, Any] | Any] = []
- interrupts: list[Interrupt] = []
- async for chunk in self.astream(
- input,
- config,
- context=context,
- stream_mode=["updates", "values"]
- if stream_mode == "values"
- else stream_mode,
- print_mode=print_mode,
- output_keys=output_keys,
- interrupt_before=interrupt_before,
- interrupt_after=interrupt_after,
- durability=durability,
- **kwargs,
- ):
- if stream_mode == "values":
- if len(chunk) == 2:
- mode, payload = cast(tuple[StreamMode, Any], chunk)
- else:
- _, mode, payload = cast(
- tuple[tuple[str, ...], StreamMode, Any], chunk
- )
- if (
- mode == "updates"
- and isinstance(payload, dict)
- and (ints := payload.get(INTERRUPT)) is not None
- ):
- interrupts.extend(ints)
- elif mode == "values":
- latest = payload
- else:
- chunks.append(chunk)
- if stream_mode == "values":
- if interrupts:
- return (
- {**latest, INTERRUPT: interrupts}
- if isinstance(latest, dict)
- else {INTERRUPT: interrupts}
- )
- return latest
- else:
- return chunks
- def clear_cache(self, nodes: Sequence[str] | None = None) -> None:
- """Clear the cache for the given nodes."""
- if not self.cache:
- raise ValueError("No cache is set for this graph. Cannot clear cache.")
- nodes = nodes or self.nodes.keys()
- # collect namespaces to clear
- namespaces: list[tuple[str, ...]] = []
- for node in nodes:
- if node in self.nodes:
- namespaces.append(
- (
- CACHE_NS_WRITES,
- (identifier(self.nodes[node]) or "__dynamic__"),
- node,
- ),
- )
- # clear cache
- self.cache.clear(namespaces)
- async def aclear_cache(self, nodes: Sequence[str] | None = None) -> None:
- """Asynchronously clear the cache for the given nodes."""
- if not self.cache:
- raise ValueError("No cache is set for this graph. Cannot clear cache.")
- nodes = nodes or self.nodes.keys()
- # collect namespaces to clear
- namespaces: list[tuple[str, ...]] = []
- for node in nodes:
- if node in self.nodes:
- namespaces.append(
- (
- CACHE_NS_WRITES,
- (identifier(self.nodes[node]) or "__dynamic__"),
- node,
- ),
- )
- # clear cache
- await self.cache.aclear(namespaces)
- def _trigger_to_nodes(nodes: dict[str, PregelNode]) -> Mapping[str, Sequence[str]]:
- """Index from a trigger to nodes that depend on it."""
- trigger_to_nodes: defaultdict[str, list[str]] = defaultdict(list)
- for name, node in nodes.items():
- for trigger in node.triggers:
- trigger_to_nodes[trigger].append(name)
- return dict(trigger_to_nodes)
- def _output(
- stream_mode: StreamMode | Sequence[StreamMode],
- print_mode: StreamMode | Sequence[StreamMode],
- stream_subgraphs: bool,
- getter: Callable[[], tuple[tuple[str, ...], str, Any]],
- empty_exc: type[Exception],
- ) -> Iterator:
- while True:
- try:
- ns, mode, payload = getter()
- except empty_exc:
- break
- if mode in print_mode:
- if stream_subgraphs and ns:
- print(
- " ".join(
- (
- get_bolded_text(f"[{mode}]"),
- get_colored_text(f"[graph={ns}]", color="yellow"),
- repr(payload),
- )
- )
- )
- else:
- print(
- " ".join(
- (
- get_bolded_text(f"[{mode}]"),
- repr(payload),
- )
- )
- )
- if mode in stream_mode:
- if stream_subgraphs and isinstance(stream_mode, list):
- yield (ns, mode, payload)
- elif isinstance(stream_mode, list):
- yield (mode, payload)
- elif stream_subgraphs:
- yield (ns, payload)
- else:
- yield payload
- def _coerce_context(
- context_schema: type[ContextT] | None, context: Any
- ) -> ContextT | None:
- """Coerce context input to the appropriate schema type.
- If context is a dict and context_schema is a dataclass or pydantic model, we coerce.
- Else, we return the context as-is.
- Args:
- context_schema: The schema type to coerce to (BaseModel, dataclass, or TypedDict)
- context: The context value to coerce
- Returns:
- The coerced context value or None if context is None
- """
- if context is None:
- return None
- if context_schema is None:
- return context
- schema_is_class = issubclass(context_schema, BaseModel) or is_dataclass(
- context_schema
- )
- if isinstance(context, dict) and schema_is_class:
- return context_schema(**context) # type: ignore[misc]
- return cast(ContextT, context)
|