| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328 |
- from __future__ import annotations
- import asyncio
- import binascii
- import concurrent.futures
- from collections import defaultdict, deque
- from collections.abc import Callable, Iterator, Mapping, Sequence
- from contextlib import (
- AbstractAsyncContextManager,
- AbstractContextManager,
- AsyncExitStack,
- ExitStack,
- )
- from datetime import datetime, timezone
- from inspect import signature
- from types import TracebackType
- from typing import (
- Any,
- Literal,
- TypeVar,
- cast,
- )
- from langchain_core.callbacks import AsyncParentRunManager, ParentRunManager
- from langchain_core.runnables import RunnableConfig
- from langgraph.cache.base import BaseCache
- from langgraph.checkpoint.base import (
- WRITES_IDX_MAP,
- BaseCheckpointSaver,
- ChannelVersions,
- Checkpoint,
- CheckpointMetadata,
- CheckpointTuple,
- PendingWrite,
- )
- from langgraph.store.base import BaseStore
- from typing_extensions import ParamSpec, Self
- from langgraph._internal._config import patch_configurable
- from langgraph._internal._constants import (
- CONF,
- CONFIG_KEY_CHECKPOINT_ID,
- CONFIG_KEY_CHECKPOINT_MAP,
- CONFIG_KEY_CHECKPOINT_NS,
- CONFIG_KEY_RESUME_MAP,
- CONFIG_KEY_RESUMING,
- CONFIG_KEY_SCRATCHPAD,
- CONFIG_KEY_STREAM,
- CONFIG_KEY_TASK_ID,
- CONFIG_KEY_THREAD_ID,
- ERROR,
- INPUT,
- INTERRUPT,
- NS_END,
- NS_SEP,
- NULL_TASK_ID,
- PUSH,
- RESUME,
- TASKS,
- )
- from langgraph._internal._scratchpad import PregelScratchpad
- from langgraph._internal._typing import EMPTY_SEQ, MISSING
- from langgraph.channels.base import BaseChannel
- from langgraph.channels.untracked_value import UntrackedValue
- from langgraph.constants import TAG_HIDDEN
- from langgraph.errors import (
- EmptyInputError,
- GraphInterrupt,
- )
- from langgraph.managed.base import (
- ManagedValueMapping,
- ManagedValueSpec,
- )
- from langgraph.pregel._algo import (
- Call,
- GetNextVersion,
- PregelTaskWrites,
- apply_writes,
- checkpoint_null_version,
- increment,
- prepare_next_tasks,
- prepare_single_task,
- sanitize_untracked_values_in_send,
- should_interrupt,
- task_path_str,
- )
- from langgraph.pregel._checkpoint import (
- channels_from_checkpoint,
- copy_checkpoint,
- create_checkpoint,
- empty_checkpoint,
- )
- from langgraph.pregel._executor import (
- AsyncBackgroundExecutor,
- BackgroundExecutor,
- Submit,
- )
- from langgraph.pregel._io import (
- map_command,
- map_input,
- map_output_updates,
- map_output_values,
- read_channels,
- )
- from langgraph.pregel._read import PregelNode
- from langgraph.pregel._utils import get_new_channel_versions, is_xxh3_128_hexdigest
- from langgraph.pregel.debug import (
- map_debug_checkpoint,
- map_debug_task_results,
- map_debug_tasks,
- )
- from langgraph.pregel.protocol import StreamChunk, StreamProtocol
- from langgraph.types import (
- All,
- CachePolicy,
- Command,
- Durability,
- PregelExecutableTask,
- RetryPolicy,
- Send,
- StreamMode,
- )
# Generic return-type variable for helper callables in this module.
V = TypeVar("V")
# Parameter specification capturing the signatures of value-factory callables
# passed to PregelLoop._emit.
P = ParamSpec("P")
# A batch of channel writes produced by a task: (channel_name, value) pairs.
WritesT = Sequence[tuple[str, Any]]
def DuplexStream(*streams: StreamProtocol) -> StreamProtocol:
    """Combine several stream protocols into a single fan-out protocol.

    Each incoming chunk is forwarded only to the streams whose mode set
    contains the chunk's mode (the second element of the chunk tuple); the
    combined protocol advertises the union of all constituent mode sets.
    """
    combined_modes = {mode for stream in streams for mode in stream.modes}

    def dispatch(value: StreamChunk) -> None:
        # forward the chunk to every stream interested in its mode
        for target in streams:
            if value[1] in target.modes:
                target(value)

    return StreamProtocol(dispatch, combined_modes)
class PregelLoop:
    """State and machinery for one run of a Pregel graph.

    Drives the superstep cycle: prepare tasks, collect their writes, apply
    writes to channels, checkpoint, and stream output — until no tasks remain
    or an interrupt/step-limit stops the run.
    """

    # --- run configuration ---
    config: RunnableConfig
    store: BaseStore | None
    stream: StreamProtocol | None
    step: int
    stop: int
    input: Any | None
    cache: BaseCache[WritesT] | None
    checkpointer: BaseCheckpointSaver | None
    nodes: Mapping[str, PregelNode]
    specs: Mapping[str, BaseChannel | ManagedValueSpec]
    input_keys: str | Sequence[str]
    output_keys: str | Sequence[str]
    stream_keys: str | Sequence[str]
    skip_done_tasks: bool
    is_nested: bool
    manager: None | AsyncParentRunManager | ParentRunManager
    interrupt_after: All | Sequence[str]
    interrupt_before: All | Sequence[str]
    durability: Durability
    retry_policy: Sequence[RetryPolicy]
    cache_policy: CachePolicy | None

    # --- checkpointer adapters (bound by sync/async subclasses) ---
    checkpointer_get_next_version: GetNextVersion
    checkpointer_put_writes: Callable[[RunnableConfig, WritesT, str], Any] | None
    # whether the bound put_writes accepts the optional task_path argument
    checkpointer_put_writes_accepts_task_path: bool
    # chained saver: waits on the previous save future before persisting,
    # so checkpoints reach the checkpointer in order
    _checkpointer_put_after_previous: (
        Callable[
            [
                concurrent.futures.Future | None,
                RunnableConfig,
                Checkpoint,
                str,
                ChannelVersions,
            ],
            Any,
        ]
        | None
    )
    _migrate_checkpoint: Callable[[Checkpoint], None] | None
    submit: Submit

    # --- per-run mutable state ---
    channels: Mapping[str, BaseChannel]
    managed: ManagedValueMapping
    checkpoint: Checkpoint
    checkpoint_id_saved: str
    checkpoint_ns: tuple[str, ...]
    checkpoint_config: RunnableConfig
    checkpoint_metadata: CheckpointMetadata
    checkpoint_pending_writes: list[PendingWrite]
    checkpoint_previous_versions: dict[str, str | float | int]
    prev_checkpoint_config: RunnableConfig | None

    # current phase of the loop's lifecycle
    status: Literal[
        "input",
        "pending",
        "done",
        "interrupt_before",
        "interrupt_after",
        "out_of_steps",
    ]
    tasks: dict[str, PregelExecutableTask]
    output: None | dict[str, Any] | Any = None
    updated_channels: set[str] | None = None

    # public
    def __init__(
        self,
        input: Any | None,
        *,
        stream: StreamProtocol | None,
        config: RunnableConfig,
        store: BaseStore | None,
        cache: BaseCache | None,
        checkpointer: BaseCheckpointSaver | None,
        nodes: Mapping[str, PregelNode],
        specs: Mapping[str, BaseChannel | ManagedValueSpec],
        input_keys: str | Sequence[str],
        output_keys: str | Sequence[str],
        stream_keys: str | Sequence[str],
        trigger_to_nodes: Mapping[str, Sequence[str]],
        durability: Durability,
        interrupt_after: All | Sequence[str] = EMPTY_SEQ,
        interrupt_before: All | Sequence[str] = EMPTY_SEQ,
        manager: None | AsyncParentRunManager | ParentRunManager = None,
        migrate_checkpoint: Callable[[Checkpoint], None] | None = None,
        retry_policy: Sequence[RetryPolicy] = (),
        cache_policy: CachePolicy | None = None,
    ) -> None:
        """Initialize run state and normalize the configs used for checkpointing."""
        self.stream = stream
        self.config = config
        self.store = store
        self.step = 0
        self.stop = 0
        self.input = input
        self.checkpointer = checkpointer
        self.cache = cache
        self.nodes = nodes
        self.specs = specs
        self.input_keys = input_keys
        self.output_keys = output_keys
        self.stream_keys = stream_keys
        self.interrupt_after = interrupt_after
        self.interrupt_before = interrupt_before
        self.manager = manager
        # a task id in config means we're running inside another graph's task
        self.is_nested = CONFIG_KEY_TASK_ID in self.config.get(CONF, {})
        # an explicit checkpoint id means time-travel: re-run tasks even if done
        self.skip_done_tasks = CONFIG_KEY_CHECKPOINT_ID not in config[CONF]
        self._migrate_checkpoint = migrate_checkpoint
        self.trigger_to_nodes = trigger_to_nodes
        self.retry_policy = retry_policy
        self.cache_policy = cache_policy
        self.durability = durability
        # fan out to both our stream and the parent's stream, if any
        if self.stream is not None and CONFIG_KEY_STREAM in config[CONF]:
            self.stream = DuplexStream(self.stream, config[CONF][CONFIG_KEY_STREAM])
        scratchpad: PregelScratchpad | None = config[CONF].get(CONFIG_KEY_SCRATCHPAD)
        if isinstance(scratchpad, PregelScratchpad):
            # if count is > 0, append to checkpoint_ns
            # if count is 0, leave as is
            if cnt := scratchpad.subgraph_counter():
                self.config = patch_configurable(
                    self.config,
                    {
                        CONFIG_KEY_CHECKPOINT_NS: NS_SEP.join(
                            (
                                config[CONF][CONFIG_KEY_CHECKPOINT_NS],
                                str(cnt),
                            )
                        )
                    },
                )
        # a top-level graph owns the root namespace; clear any stale ns/id
        if not self.is_nested and config[CONF].get(CONFIG_KEY_CHECKPOINT_NS):
            self.config = patch_configurable(
                self.config,
                {CONFIG_KEY_CHECKPOINT_NS: "", CONFIG_KEY_CHECKPOINT_ID: None},
            )
        # resolve this namespace's checkpoint id from the parent's checkpoint map
        if (
            CONFIG_KEY_CHECKPOINT_MAP in self.config[CONF]
            and self.config[CONF].get(CONFIG_KEY_CHECKPOINT_NS)
            in self.config[CONF][CONFIG_KEY_CHECKPOINT_MAP]
        ):
            self.checkpoint_config = patch_configurable(
                self.config,
                {
                    CONFIG_KEY_CHECKPOINT_ID: self.config[CONF][
                        CONFIG_KEY_CHECKPOINT_MAP
                    ][self.config[CONF][CONFIG_KEY_CHECKPOINT_NS]]
                },
            )
        else:
            self.checkpoint_config = self.config
        # thread ids must be strings for checkpointers
        if thread_id := self.checkpoint_config[CONF].get(CONFIG_KEY_THREAD_ID):
            if not isinstance(thread_id, str):
                self.checkpoint_config = patch_configurable(
                    self.checkpoint_config,
                    {CONFIG_KEY_THREAD_ID: str(thread_id)},
                )
        # split the namespace string into its path components
        self.checkpoint_ns = (
            tuple(cast(str, self.config[CONF][CONFIG_KEY_CHECKPOINT_NS]).split(NS_SEP))
            if self.config[CONF].get(CONFIG_KEY_CHECKPOINT_NS)
            else ()
        )
        self.prev_checkpoint_config = None
    def put_writes(self, task_id: str, writes: WritesT) -> None:
        """Put writes for a task, to be read by the next tick.

        Updates the in-memory pending-writes list and, unless durability is
        "exit", also persists the (possibly sanitized) writes via the
        checkpointer in a background submission.
        """
        if not writes:
            return
        # deduplicate writes to special channels, last write wins
        if all(w[0] in WRITES_IDX_MAP for w in writes):
            writes = list({w[0]: w for w in writes}.values())
        if task_id == NULL_TASK_ID:
            # writes for the null task are accumulated
            self.checkpoint_pending_writes = [
                w
                for w in self.checkpoint_pending_writes
                if w[0] != task_id or w[1] not in WRITES_IDX_MAP
            ]
            writes_to_save: WritesT = [
                w[1:] for w in self.checkpoint_pending_writes if w[0] == task_id
            ] + list(writes)
        else:
            # remove existing writes for this task
            self.checkpoint_pending_writes = [
                w for w in self.checkpoint_pending_writes if w[0] != task_id
            ]
            writes_to_save = writes
        # check if any writes are to an UntrackedValue channel
        if any(
            isinstance(channel, UntrackedValue) for channel in self.channels.values()
        ):
            # we do not persist untracked values in checkpoints
            writes_to_save = [
                # sanitize UntrackedValues that are nested within Send packets
                (
                    (c, sanitize_untracked_values_in_send(v, self.channels))
                    if c == TASKS and isinstance(v, Send)
                    else (c, v)
                )
                for c, v in writes_to_save
                # dont persist UntrackedValue channel writes
                if not isinstance(self.specs.get(c), UntrackedValue)
            ]
        # save writes
        # NOTE: the in-memory list keeps the ORIGINAL writes (untracked values
        # included); only writes_to_save is filtered for persistence
        self.checkpoint_pending_writes.extend((task_id, c, v) for c, v in writes)
        if self.durability != "exit" and self.checkpointer_put_writes is not None:
            config = patch_configurable(
                self.checkpoint_config,
                {
                    CONFIG_KEY_CHECKPOINT_NS: self.config[CONF].get(
                        CONFIG_KEY_CHECKPOINT_NS, ""
                    ),
                    CONFIG_KEY_CHECKPOINT_ID: self.checkpoint["id"],
                },
            )
            if self.checkpointer_put_writes_accepts_task_path:
                if hasattr(self, "tasks"):
                    task = self.tasks.get(task_id)
                else:
                    task = None
                self.submit(
                    self.checkpointer_put_writes,
                    config,
                    writes_to_save,
                    task_id,
                    task_path_str(task.path) if task else "",
                )
            else:
                self.submit(
                    self.checkpointer_put_writes,
                    config,
                    writes_to_save,
                    task_id,
                )
        # output writes
        if hasattr(self, "tasks"):
            self.output_writes(task_id, writes)
    def _put_pending_writes(self) -> None:
        """Persist all accumulated pending writes at once (durability=="exit" path)."""
        if self.checkpointer_put_writes is None:
            return
        if not self.checkpoint_pending_writes:
            return
        # patch config
        config = patch_configurable(
            self.checkpoint_config,
            {
                CONFIG_KEY_CHECKPOINT_NS: self.config[CONF].get(
                    CONFIG_KEY_CHECKPOINT_NS, ""
                ),
                CONFIG_KEY_CHECKPOINT_ID: self.checkpoint["id"],
            },
        )
        # group by task id
        by_task = defaultdict(list)
        for task_id, channel, value in self.checkpoint_pending_writes:
            by_task[task_id].append((channel, value))
        # submit writes to checkpointer
        for task_id, writes in by_task.items():
            if self.checkpointer_put_writes_accepts_task_path and hasattr(
                self, "tasks"
            ):
                task = self.tasks.get(task_id)
                self.submit(
                    self.checkpointer_put_writes,
                    config,
                    writes,
                    task_id,
                    task_path_str(task.path) if task else "",
                )
            else:
                self.submit(
                    self.checkpointer_put_writes,
                    config,
                    writes,
                    task_id,
                )
    def accept_push(
        self, task: PregelExecutableTask, write_idx: int, call: Call | None = None
    ) -> PregelExecutableTask | None:
        """Accept a PUSH from a task, potentially returning a new task to start.

        Returns None when prepare_single_task decides no task should run for
        this PUSH (e.g. already executed).
        """
        checkpoint_id_bytes = binascii.unhexlify(self.checkpoint["id"].replace("-", ""))
        null_version = checkpoint_null_version(self.checkpoint)
        if pushed := cast(
            PregelExecutableTask | None,
            prepare_single_task(
                (PUSH, task.path, write_idx, task.id, call),
                None,
                checkpoint=self.checkpoint,
                checkpoint_id_bytes=checkpoint_id_bytes,
                checkpoint_null_version=null_version,
                pending_writes=self.checkpoint_pending_writes,
                processes=self.nodes,
                channels=self.channels,
                managed=self.managed,
                config=task.config,
                step=self.step,
                stop=self.stop,
                for_execution=True,
                store=self.store,
                checkpointer=self.checkpointer,
                manager=self.manager,
                retry_policy=self.retry_policy,
                cache_policy=self.cache_policy,
            ),
        ):
            # produce debug output
            self._emit("tasks", map_debug_tasks, [pushed])
            # save the new task
            self.tasks[pushed.id] = pushed
            # match any pending writes to the new task
            if self.skip_done_tasks:
                self._match_writes({pushed.id: pushed})
            # return the new task, to be started if not run before
            return pushed
    def tick(self) -> bool:
        """Execute a single iteration of the Pregel loop.

        Prepares the next superstep's tasks, emits debug/checkpoint events,
        replays previously-saved writes onto resumed tasks, and enforces
        interrupt_before.

        Returns:
            True if more iterations are needed.
        """
        # check if iteration limit is reached
        if self.step > self.stop:
            self.status = "out_of_steps"
            return False
        # prepare next tasks
        self.tasks = prepare_next_tasks(
            self.checkpoint,
            self.checkpoint_pending_writes,
            self.nodes,
            self.channels,
            self.managed,
            self.config,
            self.step,
            self.stop,
            for_execution=True,
            manager=self.manager,
            store=self.store,
            checkpointer=self.checkpointer,
            trigger_to_nodes=self.trigger_to_nodes,
            updated_channels=self.updated_channels,
            retry_policy=self.retry_policy,
            cache_policy=self.cache_policy,
        )
        # produce debug output
        if self._checkpointer_put_after_previous is not None:
            self._emit(
                "checkpoints",
                map_debug_checkpoint,
                {
                    **self.checkpoint_config,
                    CONF: {
                        **self.checkpoint_config[CONF],
                        CONFIG_KEY_CHECKPOINT_ID: self.checkpoint["id"],
                    },
                },
                self.channels,
                self.stream_keys,
                self.checkpoint_metadata,
                self.tasks.values(),
                self.checkpoint_pending_writes,
                self.prev_checkpoint_config,
                self.output_keys,
            )
        # if no more tasks, we're done
        if not self.tasks:
            self.status = "done"
            return False
        # if there are pending writes from a previous loop, apply them
        if self.skip_done_tasks and self.checkpoint_pending_writes:
            self._match_writes(self.tasks)
        # before execution, check if we should interrupt
        if self.interrupt_before and should_interrupt(
            self.checkpoint, self.interrupt_before, self.tasks.values()
        ):
            self.status = "interrupt_before"
            raise GraphInterrupt()
        # produce debug output
        self._emit("tasks", map_debug_tasks, self.tasks.values())
        # print output for any tasks we applied previous writes to
        for task in self.tasks.values():
            if task.writes:
                self.output_writes(task.id, task.writes, cached=True)
        return True
    def after_tick(self) -> None:
        """Finish the superstep: apply task writes, emit values, checkpoint,
        and enforce interrupt_after."""
        # finish superstep
        writes = [w for t in self.tasks.values() for w in t.writes]
        # all tasks have finished
        self.updated_channels = apply_writes(
            self.checkpoint,
            self.channels,
            self.tasks.values(),
            self.checkpointer_get_next_version,
            self.trigger_to_nodes,
        )
        # produce values output
        if not self.updated_channels.isdisjoint(
            (self.output_keys,)
            if isinstance(self.output_keys, str)
            else self.output_keys
        ):
            self._emit(
                "values", map_output_values, self.output_keys, writes, self.channels
            )
        # clear pending writes
        self.checkpoint_pending_writes.clear()
        # "not skip_done_tasks" only applies to first tick after resuming
        self.skip_done_tasks = True
        # save checkpoint
        self._put_checkpoint({"source": "loop"})
        # after execution, check if we should interrupt
        if self.interrupt_after and should_interrupt(
            self.checkpoint, self.interrupt_after, self.tasks.values()
        ):
            self.status = "interrupt_after"
            raise GraphInterrupt()
        # unset resuming flag
        self.config[CONF].pop(CONFIG_KEY_RESUMING, None)
    def match_cached_writes(self) -> Sequence[PregelExecutableTask]:
        """Match tasks against the cache; implemented by the sync subclass."""
        raise NotImplementedError

    async def amatch_cached_writes(self) -> Sequence[PregelExecutableTask]:
        """Async counterpart of match_cached_writes; implemented by the async subclass."""
        raise NotImplementedError

    # private
- def _match_writes(self, tasks: Mapping[str, PregelExecutableTask]) -> None:
- for tid, k, v in self.checkpoint_pending_writes:
- if k in (ERROR, INTERRUPT, RESUME):
- continue
- if task := tasks.get(tid):
- task.writes.append((k, v))
- def _pending_interrupts(self) -> set[str]:
- """Return the set of interrupt ids that are pending without corresponding resume values."""
- # mapping of task ids to interrupt ids
- pending_interrupts: dict[str, str] = {}
- # set of resume task ids
- pending_resumes: set[str] = set()
- for task_id, write_type, value in self.checkpoint_pending_writes:
- if write_type == INTERRUPT:
- # interrupts is always a list, but there should only be one element
- pending_interrupts[task_id] = value[0].id
- elif write_type == RESUME:
- pending_resumes.add(task_id)
- resumed_interrupt_ids = {
- pending_interrupts[task_id]
- for task_id in pending_resumes
- if task_id in pending_interrupts
- }
- # Keep only interrupts whose interrupt_id is not resumed
- hanging_interrupts: set[str] = {
- interrupt_id
- for interrupt_id in pending_interrupts.values()
- if interrupt_id not in resumed_interrupt_ids
- }
- return hanging_interrupts
    def _first(
        self, *, input_keys: str | Sequence[str], updated_channels: set[str] | None
    ) -> set[str] | None:
        """Process the run's input before the first tick.

        Handles three cases: a Command (resume/update/goto), fresh input mapped
        onto channels, or resuming from an existing checkpoint with no input.
        Returns the set of channels updated by input, if any.
        """
        # resuming from previous checkpoint requires
        # - finding a previous checkpoint
        # - receiving None input (outer graph) or RESUMING flag (subgraph)
        configurable = self.config.get(CONF, {})
        is_resuming = bool(self.checkpoint["channel_versions"]) and bool(
            configurable.get(
                CONFIG_KEY_RESUMING,
                self.input is None
                or isinstance(self.input, Command)
                or (
                    not self.is_nested
                    and self.config.get("metadata", {}).get("run_id")
                    == self.checkpoint_metadata.get("run_id", MISSING)
                ),
            )
        )
        # map command to writes
        if isinstance(self.input, Command):
            if (resume := self.input.resume) is not None:
                if not self.checkpointer:
                    raise RuntimeError(
                        "Cannot use Command(resume=...) without checkpointer"
                    )
                # a dict keyed entirely by interrupt-id hashes is a resume map
                if resume_is_map := (
                    isinstance(resume, dict)
                    and all(is_xxh3_128_hexdigest(k) for k in resume)
                ):
                    self.config[CONF][CONFIG_KEY_RESUME_MAP] = resume
                else:
                    if len(self._pending_interrupts()) > 1:
                        raise RuntimeError(
                            "When there are multiple pending interrupts, you must specify the interrupt id when resuming. "
                            "Docs: https://docs.langchain.com/oss/python/langgraph/add-human-in-the-loop#resume-multiple-interrupts-with-one-invocation."
                        )
            writes: defaultdict[str, list[tuple[str, Any]]] = defaultdict(list)
            # group writes by task ID
            # NOTE: resume_is_map is only bound when resume is not None;
            # when resume is None map_command yields no RESUME writes, so the
            # short-circuit below never evaluates it unbound
            for tid, c, v in map_command(cmd=self.input):
                if not (c == RESUME and resume_is_map):
                    writes[tid].append((c, v))
            if not writes and not resume_is_map:
                raise EmptyInputError("Received empty Command input")
            # save writes
            for tid, ws in writes.items():
                self.put_writes(tid, ws)
        # apply NULL writes
        if null_writes := [
            w[1:] for w in self.checkpoint_pending_writes if w[0] == NULL_TASK_ID
        ]:
            null_updated_channels = apply_writes(
                self.checkpoint,
                self.channels,
                [PregelTaskWrites((), INPUT, null_writes, [])],
                self.checkpointer_get_next_version,
                self.trigger_to_nodes,
            )
            if updated_channels is not None:
                updated_channels.update(null_updated_channels)
        # proceed past previous checkpoint
        if is_resuming:
            # mark all current channel versions as seen by INTERRUPT so the
            # resumed run does not re-trigger nodes on stale versions
            self.checkpoint["versions_seen"].setdefault(INTERRUPT, {})
            for k in self.channels:
                if k in self.checkpoint["channel_versions"]:
                    version = self.checkpoint["channel_versions"][k]
                    self.checkpoint["versions_seen"][INTERRUPT][k] = version
            # produce values output
            self._emit(
                "values", map_output_values, self.output_keys, True, self.channels
            )
        # map inputs to channel updates
        elif input_writes := deque(map_input(input_keys, self.input)):
            # discard any unfinished tasks from previous checkpoint
            discard_tasks = prepare_next_tasks(
                self.checkpoint,
                self.checkpoint_pending_writes,
                self.nodes,
                self.channels,
                self.managed,
                self.config,
                self.step,
                self.stop,
                for_execution=True,
                store=None,
                checkpointer=None,
                manager=None,
                updated_channels=updated_channels,
            )
            # apply input writes
            updated_channels = apply_writes(
                self.checkpoint,
                self.channels,
                [
                    *discard_tasks.values(),
                    PregelTaskWrites((), INPUT, input_writes, []),
                ],
                self.checkpointer_get_next_version,
                self.trigger_to_nodes,
            )
            # save input checkpoint
            self.updated_channels = updated_channels
            self._put_checkpoint({"source": "input"})
        elif CONFIG_KEY_RESUMING not in configurable:
            raise EmptyInputError(f"Received no input for {input_keys}")
        # update config
        if not self.is_nested:
            self.config = patch_configurable(
                self.config, {CONFIG_KEY_RESUMING: is_resuming}
            )
        # set flag
        self.status = "pending"
        return updated_channels
    def _put_checkpoint(self, metadata: CheckpointMetadata) -> None:
        """Create a checkpoint for the current state and submit it for saving.

        When called with the stored checkpoint_metadata object itself the loop
        is exiting (durability=="exit" flush); otherwise a new step checkpoint
        is created and the step counter advances.
        """
        # assign step and parents
        exiting = metadata is self.checkpoint_metadata
        if exiting and self.checkpoint["id"] == self.checkpoint_id_saved:
            # checkpoint already saved
            return
        if not exiting:
            metadata["step"] = self.step
            metadata["parents"] = self.config[CONF].get(CONFIG_KEY_CHECKPOINT_MAP, {})
            self.checkpoint_metadata = metadata
        # do checkpoint?
        do_checkpoint = self._checkpointer_put_after_previous is not None and (
            exiting or self.durability != "exit"
        )
        # create new checkpoint
        self.checkpoint = create_checkpoint(
            self.checkpoint,
            self.channels if do_checkpoint else None,
            self.step,
            id=self.checkpoint["id"] if exiting else None,
            updated_channels=self.updated_channels,
        )
        # sanitize TASK channel in the checkpoint before saving (durability=="exit")
        if TASKS in self.checkpoint["channel_values"] and any(
            isinstance(channel, UntrackedValue) for channel in self.channels.values()
        ):
            sanitized_tasks = [
                sanitize_untracked_values_in_send(value, self.channels)
                if isinstance(value, Send)
                else value
                for value in self.checkpoint["channel_values"][TASKS]
            ]
            self.checkpoint["channel_values"][TASKS] = sanitized_tasks
        # bail if no checkpointer
        if do_checkpoint and self._checkpointer_put_after_previous is not None:
            self.prev_checkpoint_config = (
                self.checkpoint_config
                if CONFIG_KEY_CHECKPOINT_ID in self.checkpoint_config[CONF]
                and self.checkpoint_config[CONF][CONFIG_KEY_CHECKPOINT_ID]
                else None
            )
            self.checkpoint_config = {
                **self.checkpoint_config,
                CONF: {
                    **self.checkpoint_config[CONF],
                    CONFIG_KEY_CHECKPOINT_NS: self.config[CONF].get(
                        CONFIG_KEY_CHECKPOINT_NS, ""
                    ),
                },
            }
            channel_versions = self.checkpoint["channel_versions"].copy()
            new_versions = get_new_channel_versions(
                self.checkpoint_previous_versions, channel_versions
            )
            self.checkpoint_previous_versions = channel_versions
            # save it, without blocking
            # if there's a previous checkpoint save in progress, wait for it
            # ensuring checkpointers receive checkpoints in order
            self._put_checkpoint_fut = self.submit(
                self._checkpointer_put_after_previous,
                getattr(self, "_put_checkpoint_fut", None),
                self.checkpoint_config,
                copy_checkpoint(self.checkpoint),
                self.checkpoint_metadata,
                new_versions,
            )
            self.checkpoint_config = {
                **self.checkpoint_config,
                CONF: {
                    **self.checkpoint_config[CONF],
                    CONFIG_KEY_CHECKPOINT_ID: self.checkpoint["id"],
                },
            }
        if not exiting:
            # increment step
            self.step += 1
    def _suppress_interrupt(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool | None:
        """Context-manager exit helper: flush state and swallow GraphInterrupt
        for top-level graphs (returns True to suppress, None otherwise)."""
        # persist current checkpoint and writes
        if self.durability == "exit" and (
            # if it's a top graph
            not self.is_nested
            # or a nested graph with error or interrupt
            or exc_value is not None
            # or a nested graph with checkpointer=True
            or all(NS_END not in part for part in self.checkpoint_ns)
        ):
            self._put_checkpoint(self.checkpoint_metadata)
            self._put_pending_writes()
        # suppress interrupt
        suppress = isinstance(exc_value, GraphInterrupt) and not self.is_nested
        if suppress:
            # emit one last "values" event, with pending writes applied
            if (
                hasattr(self, "tasks")
                and self.checkpoint_pending_writes
                and any(task.writes for task in self.tasks.values())
            ):
                updated_channels = apply_writes(
                    self.checkpoint,
                    self.channels,
                    self.tasks.values(),
                    self.checkpointer_get_next_version,
                    self.trigger_to_nodes,
                )
                if not updated_channels.isdisjoint(
                    (self.output_keys,)
                    if isinstance(self.output_keys, str)
                    else self.output_keys
                ):
                    self._emit(
                        "values",
                        map_output_values,
                        self.output_keys,
                        [w for t in self.tasks.values() for w in t.writes],
                        self.channels,
                    )
            # emit INTERRUPT if exception is empty (otherwise emitted by put_writes)
            if exc_value is not None and (not exc_value.args or not exc_value.args[0]):
                self._emit(
                    "updates",
                    lambda: iter(
                        [{INTERRUPT: cast(GraphInterrupt, exc_value).args[0]}]
                    ),
                )
            # save final output
            self.output = read_channels(self.channels, self.output_keys)
            # suppress interrupt
            return True
        elif exc_type is None:
            # save final output
            self.output = read_channels(self.channels, self.output_keys)
    def _emit(
        self,
        mode: StreamMode,
        values: Callable[P, Iterator[Any]],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> None:
        """Emit each value produced by `values(*args, **kwargs)` to the stream.

        "checkpoints"/"tasks" events are additionally re-emitted in legacy
        "debug" mode, wrapped in a dict with step/timestamp/type metadata.
        """
        if self.stream is None:
            return
        debug_remap = mode in ("checkpoints", "tasks") and "debug" in self.stream.modes
        if mode not in self.stream.modes and not debug_remap:
            return
        for v in values(*args, **kwargs):
            if mode in self.stream.modes:
                self.stream((self.checkpoint_ns, mode, v))
            # "debug" mode is "checkpoints" or "tasks" with a wrapper dict
            if debug_remap:
                self.stream(
                    (
                        self.checkpoint_ns,
                        "debug",
                        {
                            # checkpoint events describe the step that just finished
                            "step": self.step - 1
                            if mode == "checkpoints"
                            else self.step,
                            "timestamp": datetime.now(timezone.utc).isoformat(),
                            "type": "checkpoint"
                            if mode == "checkpoints"
                            else "task_result"
                            if "result" in v
                            else "task",
                            "payload": v,
                        },
                    )
                )
    def output_writes(
        self, task_id: str, writes: WritesT, *, cached: bool = False
    ) -> None:
        """Stream a task's writes as updates/values/tasks events.

        Hidden tasks are skipped; interrupt writes are emitted specially, and
        error writes are not emitted at all.
        """
        if task := self.tasks.get(task_id):
            if task.config is not None and TAG_HIDDEN in task.config.get(
                "tags", EMPTY_SEQ
            ):
                return
            if writes[0][0] == INTERRUPT:
                # in loop.py we append a bool to the PUSH task paths to indicate
                # whether or not a call was present. If so,
                # we don't emit the interrupt as it'll be emitted by the parent
                if task.path[0] == PUSH and task.path[-1] is True:
                    return
                interrupts = [
                    {
                        INTERRUPT: tuple(
                            v
                            for w in writes
                            if w[0] == INTERRUPT
                            for v in (w[1] if isinstance(w[1], Sequence) else (w[1],))
                        )
                    }
                ]
                stream_modes = self.stream.modes if self.stream else []
                if "updates" in stream_modes:
                    self._emit("updates", lambda: iter(interrupts))
                if "values" in stream_modes:
                    current_values = read_channels(self.channels, self.output_keys)
                    # self.output_keys is a sequence, stream chunk contains entire state and interrupts
                    if isinstance(current_values, dict):
                        current_values[INTERRUPT] = interrupts[0][INTERRUPT]
                        self._emit("values", lambda: iter([current_values]))
                    # self.output_keys is a string, stream chunk contains only interrupts
                    else:
                        self._emit("values", lambda: iter(interrupts))
            elif writes[0][0] != ERROR:
                self._emit(
                    "updates",
                    map_output_updates,
                    self.output_keys,
                    [(task, writes)],
                    cached,
                )
            if not cached:
                self._emit(
                    "tasks",
                    map_debug_task_results,
                    (task, writes),
                    self.stream_keys,
                )
class SyncPregelLoop(PregelLoop, AbstractContextManager):
    """Synchronous PregelLoop; manages its executor/resources via an ExitStack."""

    def __init__(
        self,
        input: Any | None,
        *,
        stream: StreamProtocol | None,
        config: RunnableConfig,
        store: BaseStore | None,
        cache: BaseCache | None,
        checkpointer: BaseCheckpointSaver | None,
        nodes: Mapping[str, PregelNode],
        specs: Mapping[str, BaseChannel | ManagedValueSpec],
        trigger_to_nodes: Mapping[str, Sequence[str]],
        durability: Durability,
        manager: None | AsyncParentRunManager | ParentRunManager = None,
        interrupt_after: All | Sequence[str] = EMPTY_SEQ,
        interrupt_before: All | Sequence[str] = EMPTY_SEQ,
        input_keys: str | Sequence[str] = EMPTY_SEQ,
        output_keys: str | Sequence[str] = EMPTY_SEQ,
        stream_keys: str | Sequence[str] = EMPTY_SEQ,
        migrate_checkpoint: Callable[[Checkpoint], None] | None = None,
        retry_policy: Sequence[RetryPolicy] = (),
        cache_policy: CachePolicy | None = None,
    ) -> None:
        """Initialize the loop and bind the sync checkpointer adapters."""
        super().__init__(
            input,
            stream=stream,
            config=config,
            checkpointer=checkpointer,
            cache=cache,
            store=store,
            nodes=nodes,
            specs=specs,
            input_keys=input_keys,
            output_keys=output_keys,
            stream_keys=stream_keys,
            interrupt_after=interrupt_after,
            interrupt_before=interrupt_before,
            manager=manager,
            migrate_checkpoint=migrate_checkpoint,
            trigger_to_nodes=trigger_to_nodes,
            retry_policy=retry_policy,
            cache_policy=cache_policy,
            durability=durability,
        )
        self.stack = ExitStack()
        if checkpointer:
            self.checkpointer_get_next_version = checkpointer.get_next_version
            self.checkpointer_put_writes = checkpointer.put_writes
            # newer checkpointers accept a task_path argument to put_writes
            self.checkpointer_put_writes_accepts_task_path = (
                signature(checkpointer.put_writes).parameters.get("task_path")
                is not None
            )
        else:
            self.checkpointer_get_next_version = increment
            self._checkpointer_put_after_previous = None  # type: ignore[assignment]
            self.checkpointer_put_writes = None
            self.checkpointer_put_writes_accepts_task_path = False
- def _checkpointer_put_after_previous(
- self,
- prev: concurrent.futures.Future | None,
- config: RunnableConfig,
- checkpoint: Checkpoint,
- metadata: CheckpointMetadata,
- new_versions: ChannelVersions,
- ) -> RunnableConfig:
- try:
- if prev is not None:
- prev.result()
- finally:
- cast(BaseCheckpointSaver, self.checkpointer).put(
- config, checkpoint, metadata, new_versions
- )
- def match_cached_writes(self) -> Sequence[PregelExecutableTask]:
- if self.cache is None:
- return ()
- matched: list[PregelExecutableTask] = []
- if cached := {
- (t.cache_key.ns, t.cache_key.key): t
- for t in self.tasks.values()
- if t.cache_key and not t.writes
- }:
- for key, values in self.cache.get(tuple(cached)).items():
- task = cached[key]
- task.writes.extend(values)
- matched.append(task)
- return matched
- def accept_push(
- self, task: PregelExecutableTask, write_idx: int, call: Call | None = None
- ) -> PregelExecutableTask | None:
- if pushed := super().accept_push(task, write_idx, call):
- for task in self.match_cached_writes():
- self.output_writes(task.id, task.writes, cached=True)
- return pushed
- def put_writes(self, task_id: str, writes: WritesT) -> None:
- """Put writes for a task, to be read by the next tick."""
- super().put_writes(task_id, writes)
- if not writes or self.cache is None or not hasattr(self, "tasks"):
- return
- task = self.tasks.get(task_id)
- if task is None or task.cache_key is None:
- return
- self.submit(
- self.cache.set,
- {
- (task.cache_key.ns, task.cache_key.key): (
- task.writes,
- task.cache_key.ttl,
- )
- },
- )
- # context manager
- def __enter__(self) -> Self:
- if self.checkpointer:
- saved = self.checkpointer.get_tuple(self.checkpoint_config)
- else:
- saved = None
- if saved is None:
- saved = CheckpointTuple(
- self.checkpoint_config, empty_checkpoint(), {"step": -2}, None, []
- )
- elif self._migrate_checkpoint is not None:
- self._migrate_checkpoint(saved.checkpoint)
- self.checkpoint_config = {
- **self.checkpoint_config,
- **saved.config,
- CONF: {
- CONFIG_KEY_CHECKPOINT_NS: "",
- **self.checkpoint_config.get(CONF, {}),
- **saved.config.get(CONF, {}),
- },
- }
- self.prev_checkpoint_config = saved.parent_config
- self.checkpoint_id_saved = saved.checkpoint["id"]
- self.checkpoint = saved.checkpoint
- self.checkpoint_metadata = saved.metadata
- self.checkpoint_pending_writes = (
- [(str(tid), k, v) for tid, k, v in saved.pending_writes]
- if saved.pending_writes is not None
- else []
- )
- self.submit = self.stack.enter_context(BackgroundExecutor(self.config))
- self.channels, self.managed = channels_from_checkpoint(
- self.specs, self.checkpoint
- )
- self.stack.push(self._suppress_interrupt)
- self.status = "input"
- self.step = self.checkpoint_metadata["step"] + 1
- self.stop = self.step + self.config["recursion_limit"] + 1
- self.checkpoint_previous_versions = self.checkpoint["channel_versions"].copy()
- self.updated_channels = self._first(
- input_keys=self.input_keys,
- updated_channels=set(self.checkpoint.get("updated_channels")) # type: ignore[arg-type]
- if self.checkpoint.get("updated_channels")
- else None,
- )
- return self
- def __exit__(
- self,
- exc_type: type[BaseException] | None,
- exc_value: BaseException | None,
- traceback: TracebackType | None,
- ) -> bool | None:
- # unwind stack
- return self.stack.__exit__(exc_type, exc_value, traceback)
class AsyncPregelLoop(PregelLoop, AbstractAsyncContextManager):
    """Asynchronous implementation of the Pregel execution loop.

    Mirrors ``SyncPregelLoop`` but uses an ``AsyncExitStack``, the
    checkpointer's async methods (``aput`` / ``aput_writes`` / ``aget_tuple``)
    and an ``AsyncBackgroundExecutor``. Use as an async context manager:
    ``__aenter__`` loads (or initializes) the checkpoint and channels,
    ``__aexit__`` unwinds the resource stack.
    """

    def __init__(
        self,
        input: Any | None,
        *,
        stream: StreamProtocol | None,
        config: RunnableConfig,
        store: BaseStore | None,
        cache: BaseCache | None,
        checkpointer: BaseCheckpointSaver | None,
        nodes: Mapping[str, PregelNode],
        specs: Mapping[str, BaseChannel | ManagedValueSpec],
        trigger_to_nodes: Mapping[str, Sequence[str]],
        durability: Durability,
        interrupt_after: All | Sequence[str] = EMPTY_SEQ,
        interrupt_before: All | Sequence[str] = EMPTY_SEQ,
        manager: None | AsyncParentRunManager | ParentRunManager = None,
        input_keys: str | Sequence[str] = EMPTY_SEQ,
        output_keys: str | Sequence[str] = EMPTY_SEQ,
        stream_keys: str | Sequence[str] = EMPTY_SEQ,
        migrate_checkpoint: Callable[[Checkpoint], None] | None = None,
        retry_policy: Sequence[RetryPolicy] = (),
        cache_policy: CachePolicy | None = None,
    ) -> None:
        super().__init__(
            input,
            stream=stream,
            config=config,
            checkpointer=checkpointer,
            cache=cache,
            store=store,
            nodes=nodes,
            specs=specs,
            input_keys=input_keys,
            output_keys=output_keys,
            stream_keys=stream_keys,
            interrupt_after=interrupt_after,
            interrupt_before=interrupt_before,
            manager=manager,
            migrate_checkpoint=migrate_checkpoint,
            trigger_to_nodes=trigger_to_nodes,
            retry_policy=retry_policy,
            cache_policy=cache_policy,
            durability=durability,
        )
        self.stack = AsyncExitStack()
        if checkpointer:
            self.checkpointer_get_next_version = checkpointer.get_next_version
            self.checkpointer_put_writes = checkpointer.aput_writes
            # older checkpointer implementations may not accept a `task_path`
            # keyword; probe the signature once up front
            self.checkpointer_put_writes_accepts_task_path = (
                signature(checkpointer.aput_writes).parameters.get("task_path")
                is not None
            )
        else:
            self.checkpointer_get_next_version = increment
            self._checkpointer_put_after_previous = None  # type: ignore[assignment]
            self.checkpointer_put_writes = None
            self.checkpointer_put_writes_accepts_task_path = False

    async def _checkpointer_put_after_previous(
        self,
        prev: asyncio.Task | None,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: ChannelVersions,
    ) -> RunnableConfig:
        """Persist a checkpoint after the previous save has settled.

        Awaits ``prev`` (the task of the preceding save) so checkpoints are
        stored in order; the ``finally`` ensures this checkpoint is still
        saved even if the previous save raised.
        """
        try:
            if prev is not None:
                await prev
        finally:
            await cast(BaseCheckpointSaver, self.checkpointer).aput(
                config, checkpoint, metadata, new_versions
            )

    async def amatch_cached_writes(self) -> Sequence[PregelExecutableTask]:
        """Apply cached writes to eligible tasks and return the tasks matched.

        A task is eligible when it has a cache key and no writes yet; for each
        cache hit, the cached writes are appended to the task so it need not
        be executed again.
        """
        if self.cache is None:
            return []
        matched: list[PregelExecutableTask] = []
        if cached := {
            (t.cache_key.ns, t.cache_key.key): t
            for t in self.tasks.values()
            if t.cache_key and not t.writes
        }:
            for key, values in (await self.cache.aget(tuple(cached))).items():
                task = cached[key]
                task.writes.extend(values)
                matched.append(task)
        return matched

    async def aaccept_push(
        self, task: PregelExecutableTask, write_idx: int, call: Call | None = None
    ) -> PregelExecutableTask | None:
        """Accept a PUSH write; emit output for any cache hits it enables."""
        if pushed := super().accept_push(task, write_idx, call):
            for task in await self.amatch_cached_writes():
                self.output_writes(task.id, task.writes, cached=True)
        return pushed

    def put_writes(self, task_id: str, writes: WritesT) -> None:
        """Put writes for a task, to be read by the next tick."""
        super().put_writes(task_id, writes)
        if not writes or self.cache is None or not hasattr(self, "tasks"):
            return
        task = self.tasks.get(task_id)
        if task is None or task.cache_key is None:
            return
        if writes[0][0] in (INTERRUPT, ERROR):
            # only cache successful tasks
            return
        # fire-and-forget cache write via the background executor
        self.submit(
            self.cache.aset,
            {
                (task.cache_key.ns, task.cache_key.key): (
                    task.writes,
                    task.cache_key.ttl,
                )
            },
        )

    # context manager

    async def __aenter__(self) -> Self:
        """Load (or initialize) the checkpoint and prepare the first tick."""
        if self.checkpointer:
            saved = await self.checkpointer.aget_tuple(self.checkpoint_config)
        else:
            saved = None
        if saved is None:
            # no prior checkpoint: start from an empty one at step -2 so the
            # first executed step becomes -1 (input step)
            saved = CheckpointTuple(
                self.checkpoint_config, empty_checkpoint(), {"step": -2}, None, []
            )
        elif self._migrate_checkpoint is not None:
            self._migrate_checkpoint(saved.checkpoint)
        # merge saved config over the provided one, preserving CONF sub-keys
        self.checkpoint_config = {
            **self.checkpoint_config,
            **saved.config,
            CONF: {
                CONFIG_KEY_CHECKPOINT_NS: "",
                **self.checkpoint_config.get(CONF, {}),
                **saved.config.get(CONF, {}),
            },
        }
        self.prev_checkpoint_config = saved.parent_config
        self.checkpoint_id_saved = saved.checkpoint["id"]
        self.checkpoint = saved.checkpoint
        self.checkpoint_metadata = saved.metadata
        self.checkpoint_pending_writes = (
            [(str(tid), k, v) for tid, k, v in saved.pending_writes]
            if saved.pending_writes is not None
            else []
        )
        self.submit = await self.stack.enter_async_context(
            AsyncBackgroundExecutor(self.config)
        )
        self.channels, self.managed = channels_from_checkpoint(
            self.specs, self.checkpoint
        )
        self.stack.push(self._suppress_interrupt)
        self.status = "input"
        self.step = self.checkpoint_metadata["step"] + 1
        self.stop = self.step + self.config["recursion_limit"] + 1
        self.checkpoint_previous_versions = self.checkpoint["channel_versions"].copy()
        self.updated_channels = self._first(
            input_keys=self.input_keys,
            updated_channels=set(self.checkpoint.get("updated_channels"))  # type: ignore[arg-type]
            if self.checkpoint.get("updated_channels")
            else None,
        )
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool | None:
        # unwind stack
        exit_task = asyncio.create_task(
            self.stack.__aexit__(exc_type, exc_value, traceback)
        )
        try:
            return await exit_task
        except asyncio.CancelledError as e:
            # Bubble up the exit task upon cancellation to permit the API
            # consumer to await it before e.g., reusing the DB connection.
            e.args = (*e.args, exit_task)
            raise
|