| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233 |
- from __future__ import annotations
- import binascii
- import itertools
- import sys
- import threading
- from collections import defaultdict, deque
- from collections.abc import Callable, Iterable, Mapping, Sequence
- from copy import copy
- from functools import partial
- from hashlib import sha1
- from typing import (
- Any,
- Literal,
- NamedTuple,
- Protocol,
- cast,
- overload,
- )
- from langchain_core.callbacks import Callbacks
- from langchain_core.callbacks.manager import AsyncParentRunManager, ParentRunManager
- from langchain_core.runnables.config import RunnableConfig
- from langgraph.checkpoint.base import (
- BaseCheckpointSaver,
- ChannelVersions,
- Checkpoint,
- PendingWrite,
- V,
- )
- from langgraph.store.base import BaseStore
- from xxhash import xxh3_128_hexdigest
- from langgraph._internal._config import merge_configs, patch_config
- from langgraph._internal._constants import (
- CACHE_NS_WRITES,
- CONF,
- CONFIG_KEY_CHECKPOINT_ID,
- CONFIG_KEY_CHECKPOINT_MAP,
- CONFIG_KEY_CHECKPOINT_NS,
- CONFIG_KEY_CHECKPOINTER,
- CONFIG_KEY_READ,
- CONFIG_KEY_RESUME_MAP,
- CONFIG_KEY_RUNTIME,
- CONFIG_KEY_SCRATCHPAD,
- CONFIG_KEY_SEND,
- CONFIG_KEY_TASK_ID,
- ERROR,
- INTERRUPT,
- NO_WRITES,
- NS_END,
- NS_SEP,
- NULL_TASK_ID,
- PREVIOUS,
- PULL,
- PUSH,
- RESERVED,
- RESUME,
- RETURN,
- TASKS,
- )
- from langgraph._internal._scratchpad import PregelScratchpad
- from langgraph._internal._typing import EMPTY_SEQ, MISSING
- from langgraph.channels.base import BaseChannel
- from langgraph.channels.topic import Topic
- from langgraph.channels.untracked_value import UntrackedValue
- from langgraph.constants import TAG_HIDDEN
- from langgraph.managed.base import ManagedValueMapping
- from langgraph.pregel._call import get_runnable_for_task, identifier
- from langgraph.pregel._io import read_channels
- from langgraph.pregel._log import logger
- from langgraph.pregel._read import INPUT_CACHE_KEY_TYPE, PregelNode
- from langgraph.runtime import DEFAULT_RUNTIME, Runtime
- from langgraph.types import (
- All,
- CacheKey,
- CachePolicy,
- PregelExecutableTask,
- PregelTask,
- RetryPolicy,
- Send,
- )
# Signature of a channel-versioning function: it receives the current highest
# version (or None when there is none yet) plus a second argument that is
# always None, and returns the next version to assign.
GetNextVersion = Callable[[V | None, None], V]
# True on Python 3.11 and newer; gates the use of `exc.add_note(...)` below.
SUPPORTS_EXC_NOTES = sys.version_info >= (3, 11)
class WritesProtocol(Protocol):
    """Structural type for any object carrying writes destined for a checkpoint.

    Implemented by PregelTaskWrites and PregelExecutableTask.
    """

    @property
    def path(self) -> tuple[str | int | tuple, ...]:
        """Path identifying the writer; its first three parts are used for ordering."""

    @property
    def name(self) -> str:
        """Name of the writer, used as the key into the checkpoint's seen-versions."""

    @property
    def writes(self) -> Sequence[tuple[str, Any]]:
        """The `(channel, value)` pairs produced by the writer."""

    @property
    def triggers(self) -> Sequence[str]:
        """Names of the channels that triggered the writer."""
class PregelTaskWrites(NamedTuple):
    """Minimal WritesProtocol implementation.

    Intended for writes that do not come from a runnable task, e.g. graph
    input or `update_state`.
    """

    # Path identifying the origin of the writes.
    path: tuple[str | int | tuple, ...]
    # Name under which the writes are recorded.
    name: str
    # `(channel, value)` pairs to apply.
    writes: Sequence[tuple[str, Any]]
    # Channels that triggered the writer (may be empty).
    triggers: Sequence[str]
class Call:
    """A pending function invocation together with its execution policies.

    `input` holds the positional arguments tuple and the keyword arguments
    dict the function is to be called with.
    """

    __slots__ = ("func", "input", "retry_policy", "cache_policy", "callbacks")

    func: Callable
    input: tuple[tuple[Any, ...], dict[str, Any]]
    retry_policy: Sequence[RetryPolicy] | None
    cache_policy: CachePolicy | None
    callbacks: Callbacks

    def __init__(
        self,
        func: Callable,
        input: tuple[tuple[Any, ...], dict[str, Any]],
        *,
        retry_policy: Sequence[RetryPolicy] | None,
        cache_policy: CachePolicy | None,
        callbacks: Callbacks,
    ) -> None:
        # What to call, and with which arguments.
        self.func, self.input = func, input
        # How to run it.
        self.retry_policy, self.cache_policy = retry_policy, cache_policy
        self.callbacks = callbacks
def should_interrupt(
    checkpoint: Checkpoint,
    interrupt_nodes: All | Sequence[str],
    tasks: Iterable[PregelExecutableTask],
) -> list[PregelExecutableTask]:
    """Return the tasks that should cause the graph to be interrupted.

    An interrupt is only considered when at least one channel version is newer
    than the version recorded at the previous interrupt. In that case the
    result is every task named in `interrupt_nodes`, or, when
    `interrupt_nodes` is `"*"`, every task that is not tagged as hidden.
    """
    versions = checkpoint["channel_versions"]
    # "zero" value of whatever type the versions use (None when there are none)
    sample = next(iter(versions.values()), None)
    null_version = type(sample)()  # type: ignore[misc]
    seen = checkpoint["versions_seen"].get(INTERRUPT, {})

    # nothing changed since the last interrupt -> nothing to interrupt on
    for chan, version in versions.items():
        if version > seen.get(chan, null_version):  # type: ignore[operator]
            break
    else:
        return []

    if interrupt_nodes == "*":
        return [
            task
            for task in tasks
            if not task.config
            or TAG_HIDDEN not in task.config.get("tags", EMPTY_SEQ)
        ]
    return [task for task in tasks if task.name in interrupt_nodes]
def local_read(
    scratchpad: PregelScratchpad,
    channels: Mapping[str, BaseChannel],
    managed: ManagedValueMapping,
    task: WritesProtocol,
    select: list[str] | str,
    fresh: bool = False,
) -> dict[str, Any] | Any:
    """Read current state on behalf of a task (injected under CONFIG_KEY_READ).

    Used by conditional edges to read a copy of the state reflecting the
    writes from that node only.

    Args:
        scratchpad: Scratchpad handed to managed values when they are read.
        channels: The graph channels.
        managed: The managed values; keys of `select` found here are read
            from `managed` rather than from `channels`.
        task: The task whose own writes may be overlaid on the state.
        select: One channel name, or a list of channel / managed value names.
        fresh: When true, the task's writes are applied to copies of the
            channels before reading; otherwise channels are read as they are.

    Returns:
        Whatever `read_channels` returns for `select`, with managed values
        merged in when `select` is a list that names any.
    """
    staged: dict[str, list[Any]] = defaultdict(list)
    if isinstance(select, str):
        managed_keys: list[str] = []
        wanted: Sequence[str] = [select]
    else:
        managed_keys = [key for key in select if key in managed]
        select = [key for key in select if key not in managed]
        wanted = select
    # collect this task's writes to the requested channels
    for chan, value in task.writes:
        if chan in wanted:
            staged[chan].append(value)

    source: Mapping[str, BaseChannel] = channels
    if fresh:
        # overlay the writes on copies, leaving the real channels untouched
        overlay: dict[str, BaseChannel] = {}
        for key, channel in channels.items():
            clone = channel.copy()
            clone.update(staged[key])
            overlay[key] = clone
        source = overlay
    values = read_channels(source, select)

    if managed_keys:
        values.update({key: managed[key].get(scratchpad) for key in managed_keys})
    return values
def increment(current: int | None, channel: None) -> int:
    """Default channel versioning: start at 1, then count up by one.

    `channel` is accepted for signature compatibility and is not used.
    """
    if current is None:
        return 1
    return current + 1
def apply_writes(
    checkpoint: Checkpoint,
    channels: Mapping[str, BaseChannel],
    tasks: Iterable[WritesProtocol],
    get_next_version: GetNextVersion | None,
    trigger_to_nodes: Mapping[str, Sequence[str]],
) -> set[str]:
    """Apply the writes of a set of tasks (usually one Pregel step) to the
    checkpoint and the channels.

    Args:
        checkpoint: The checkpoint to update.
        channels: The channels to update.
        tasks: The tasks to apply writes from.
        get_next_version: Optional function to determine the next version of a
            channel. When None, channel versions are left untouched.
        trigger_to_nodes: Mapping of channel names to the set of nodes that can
            be triggered by updates to that channel.

    Returns:
        Set of channels that were updated in this step.
    """
    # Deterministic application order: sort on the first three path parts only
    # (later parts hold e.g. task ids, which are not good for sorting).
    ordered = sorted(tasks, key=lambda t: task_path_str(t.path[:3]))
    # Without any triggered task we are applying writes from the null task
    # only, so channels that were not written to must be left alone.
    bump_step = any(t.triggers for t in ordered)

    versions = checkpoint["channel_versions"]
    versions_seen = checkpoint["versions_seen"]

    # Record, per task, the versions of the channels that triggered it.
    for task in ordered:
        seen_by_task = versions_seen.setdefault(task.name, {})
        seen_by_task.update(
            {chan: versions[chan] for chan in task.triggers if chan in versions}
        )

    # One new version, derived from the current maximum, is shared by every
    # channel changed during this call.
    next_version = None
    if get_next_version is not None:
        current_max = max(versions.values()) if versions else None
        next_version = get_next_version(current_max, None)

    # Consume every channel that was read as a trigger.
    triggered_channels = {
        chan
        for task in ordered
        for chan in task.triggers
        if chan not in RESERVED and chan in channels
    }
    for chan in triggered_channels:
        if channels[chan].consume() and next_version is not None:
            versions[chan] = next_version

    # Group the written values per target channel.
    not_channel_writes = (NO_WRITES, PUSH, RESUME, INTERRUPT, RETURN, ERROR)
    pending_writes_by_channel: dict[str, list[Any]] = defaultdict(list)
    for task in ordered:
        for chan, val in task.writes:
            if chan in not_channel_writes:
                continue
            if chan in channels:
                pending_writes_by_channel[chan].append(val)
            else:
                logger.warning(
                    f"Task {task.name} with path {task.path} wrote to unknown channel {chan}, ignoring it."
                )

    # Apply the grouped writes (every key is known to be in `channels`).
    updated_channels: set[str] = set()
    for chan, vals in pending_writes_by_channel.items():
        channel = channels[chan]
        if channel.update(vals) and next_version is not None:
            versions[chan] = next_version
            # unavailable channels can't trigger tasks, so don't add them
            if channel.is_available():
                updated_channels.add(chan)

    # Channels that weren't updated in this step are notified of a new step.
    if bump_step:
        for chan, channel in channels.items():
            if not channel.is_available() or chan in updated_channels:
                continue
            if channel.update(EMPTY_SEQ) and next_version is not None:
                versions[chan] = next_version
                # unavailable channels can't trigger tasks, so don't add them
                if channel.is_available():
                    updated_channels.add(chan)

    # If this is (tentatively) the last superstep, notify all channels of finish.
    if bump_step and updated_channels.isdisjoint(trigger_to_nodes):
        for chan, channel in channels.items():
            if channel.finish() and next_version is not None:
                versions[chan] = next_version
                # unavailable channels can't trigger tasks, so don't add them
                if channel.is_available():
                    updated_channels.add(chan)

    return updated_channels
@overload
def prepare_next_tasks(
    checkpoint: Checkpoint,
    pending_writes: list[PendingWrite],
    processes: Mapping[str, PregelNode],
    channels: Mapping[str, BaseChannel],
    managed: ManagedValueMapping,
    config: RunnableConfig,
    step: int,
    stop: int,
    *,
    for_execution: Literal[False],
    store: Literal[None] = None,
    checkpointer: Literal[None] = None,
    manager: Literal[None] = None,
    trigger_to_nodes: Mapping[str, Sequence[str]] | None = None,
    updated_channels: set[str] | None = None,
    retry_policy: Sequence[RetryPolicy] = (),
    cache_policy: Literal[None] = None,
) -> dict[str, PregelTask]: ...


@overload
def prepare_next_tasks(
    checkpoint: Checkpoint,
    pending_writes: list[PendingWrite],
    processes: Mapping[str, PregelNode],
    channels: Mapping[str, BaseChannel],
    managed: ManagedValueMapping,
    config: RunnableConfig,
    step: int,
    stop: int,
    *,
    for_execution: Literal[True],
    store: BaseStore | None,
    checkpointer: BaseCheckpointSaver | None,
    manager: None | ParentRunManager | AsyncParentRunManager,
    trigger_to_nodes: Mapping[str, Sequence[str]] | None = None,
    updated_channels: set[str] | None = None,
    retry_policy: Sequence[RetryPolicy] = (),
    cache_policy: CachePolicy | None = None,
) -> dict[str, PregelExecutableTask]: ...


def prepare_next_tasks(
    checkpoint: Checkpoint,
    pending_writes: list[PendingWrite],
    processes: Mapping[str, PregelNode],
    channels: Mapping[str, BaseChannel],
    managed: ManagedValueMapping,
    config: RunnableConfig,
    step: int,
    stop: int,
    *,
    for_execution: bool,
    store: BaseStore | None = None,
    checkpointer: BaseCheckpointSaver | None = None,
    manager: None | ParentRunManager | AsyncParentRunManager = None,
    trigger_to_nodes: Mapping[str, Sequence[str]] | None = None,
    updated_channels: set[str] | None = None,
    retry_policy: Sequence[RetryPolicy] = (),
    cache_policy: CachePolicy | None = None,
) -> dict[str, PregelTask] | dict[str, PregelExecutableTask]:
    """Build the set of tasks that make up the next Pregel step.

    Args:
        checkpoint: The current checkpoint.
        pending_writes: The list of pending writes.
        processes: The mapping of process names to PregelNode instances.
        channels: The mapping of channel names to BaseChannel instances.
        managed: The mapping of managed value names to functions.
        config: The `Runnable` configuration.
        step: The current step.
        stop: Forwarded unchanged to every task that is prepared.
        for_execution: Whether the tasks are being prepared for execution.
        store: An instance of BaseStore to make available within tasks.
        checkpointer: `Checkpointer` instance used for saving checkpoints.
        manager: The parent run manager to use for the tasks.
        trigger_to_nodes: Optional mapping of channel names to the nodes that
            can be triggered by that channel.
        updated_channels: Optional set of channel names updated during the
            previous step. Together with `trigger_to_nodes` it narrows down
            which nodes need to be examined.
        retry_policy: Retry policies used when a node defines none of its own.
        cache_policy: Cache policy used when a node defines none of its own.

    Returns:
        A dictionary of tasks to be executed, keyed by task id. It is the
        union of all PUSH tasks (Sends) and PULL tasks (nodes triggered by
        edges), PUSH tasks first.
    """
    input_cache: dict[INPUT_CACHE_KEY_TYPE, Any] = {}
    checkpoint_id_bytes = binascii.unhexlify(checkpoint["id"].replace("-", ""))
    null_version = checkpoint_null_version(checkpoint)

    # Everything except the task path is the same for each task we prepare.
    prepare = partial(
        prepare_single_task,
        checkpoint=checkpoint,
        checkpoint_id_bytes=checkpoint_id_bytes,
        checkpoint_null_version=null_version,
        pending_writes=pending_writes,
        processes=processes,
        channels=channels,
        managed=managed,
        config=config,
        step=step,
        stop=stop,
        for_execution=for_execution,
        store=store,
        checkpointer=checkpointer,
        manager=manager,
        input_cache=input_cache,
        cache_policy=cache_policy,
        retry_policy=retry_policy,
    )
    prepared: dict[str, Any] = {}

    # PUSH tasks: one per Send waiting in the TASKS channel.
    tasks_channel = cast(Topic[Send] | None, channels.get(TASKS))
    if tasks_channel and tasks_channel.is_available():
        for position, _send in enumerate(tasks_channel.get()):
            push_task = prepare((PUSH, position), None)
            if push_task:
                prepared[push_task.id] = push_task

    # Optimization: knowing which channels were updated in the previous step
    # and which nodes each channel triggers lets us look only at the nodes that
    # can possibly run, instead of cycling through all of them.
    candidate_nodes: Iterable[str]
    if updated_channels and trigger_to_nodes:
        triggered_nodes: set[str] = {
            node
            for channel in updated_channels
            for node in trigger_to_nodes.get(channel) or ()
        }
        # sorted for a deterministic task order
        candidate_nodes = sorted(triggered_nodes)
    elif not checkpoint["channel_versions"]:
        candidate_nodes = ()
    else:
        candidate_nodes = processes.keys()

    # PULL tasks: candidate nodes whose trigger channels actually changed.
    for node_name in candidate_nodes:
        pull_task = prepare((PULL, node_name), None)
        if pull_task:
            prepared[pull_task.id] = pull_task
    return prepared
# Trigger tuple shared by every PUSH task (Send packets and functional-API calls).
PUSH_TRIGGER = (PUSH,)
class _TaskIDFn(Protocol):
    """Callable that derives a task id string from a namespace and any number of parts."""

    def __call__(self, namespace: bytes, *parts: str | bytes) -> str: ...
def prepare_single_task(
    task_path: tuple[Any, ...],
    task_id_checksum: str | None,
    *,
    checkpoint: Checkpoint,
    checkpoint_id_bytes: bytes,
    checkpoint_null_version: V | None,
    pending_writes: list[PendingWrite],
    processes: Mapping[str, PregelNode],
    channels: Mapping[str, BaseChannel],
    managed: ManagedValueMapping,
    config: RunnableConfig,
    step: int,
    stop: int,
    for_execution: bool,
    store: BaseStore | None = None,
    checkpointer: BaseCheckpointSaver | None = None,
    manager: None | ParentRunManager | AsyncParentRunManager = None,
    input_cache: dict[INPUT_CACHE_KEY_TYPE, Any] | None = None,
    cache_policy: CachePolicy | None = None,
    retry_policy: Sequence[RetryPolicy] = (),
) -> None | PregelTask | PregelExecutableTask:
    """Prepare a single task for the next Pregel step, given a task path which
    uniquely identifies a PUSH or PULL task within the graph.

    Args:
        task_path: `(PUSH, ..., Call)` for a functional-API call, `(PUSH, idx)`
            for a pending Send, or `(PULL, node name)` for a triggered node.
        task_id_checksum: When not None, the computed task id must equal it.
        checkpoint: The current checkpoint.
        checkpoint_id_bytes: The checkpoint id as raw bytes, used as the
            namespace for task id generation.
        checkpoint_null_version: "Zero" value of the checkpoint's version type,
            or None when the checkpoint has no channel versions (in which case
            no PULL task is produced).
        pending_writes: The list of pending writes.
        processes: The mapping of process names to PregelNode instances.
        channels: The mapping of channel names to BaseChannel instances.
        managed: The mapping of managed value names to functions.
        config: The `Runnable` configuration. A missing `configurable` section
            is treated as empty.
        step: The current step.
        stop: Forwarded to the task scratchpad.
        for_execution: When true return a PregelExecutableTask, otherwise a
            lightweight PregelTask.
        store: Store made available to the task through its runtime.
        checkpointer: Checkpointer exposed to the task; falls back to the one
            found in `config`.
        manager: Parent run manager used to derive the task callbacks.
        input_cache: Cache passed through to the task input computation.
        cache_policy: Cache policy used when the node defines none.
        retry_policy: Retry policies used when the node defines none.

    Returns:
        The prepared task, or None when the path does not resolve to a task
        that should run (unknown node, node not triggered, no input, ...).

    Raises:
        AssertionError: If `task_id_checksum` is given and does not match.
    """
    configurable = config.get(CONF, {})
    parent_ns = configurable.get(CONFIG_KEY_CHECKPOINT_NS, "")
    task_id_func = _xxhash_str if checkpoint["v"] > 1 else _uuid5_str

    if task_path[0] == PUSH and isinstance(task_path[-1], Call):
        return prepare_push_task_functional(
            cast(tuple[str, tuple, int, str, Call], task_path),
            task_id_checksum,
            checkpoint=checkpoint,
            checkpoint_id_bytes=checkpoint_id_bytes,
            pending_writes=pending_writes,
            channels=channels,
            managed=managed,
            config=config,
            step=step,
            stop=stop,
            for_execution=for_execution,
            store=store,
            checkpointer=checkpointer,
            manager=manager,
            cache_policy=cache_policy,
            retry_policy=retry_policy,
            parent_ns=parent_ns,
            task_id_func=task_id_func,
        )
    if task_path[0] == PUSH:
        return prepare_push_task_send(
            cast(tuple[str, tuple], task_path),
            task_id_checksum,
            checkpoint=checkpoint,
            checkpoint_id_bytes=checkpoint_id_bytes,
            pending_writes=pending_writes,
            channels=channels,
            managed=managed,
            config=config,
            step=step,
            processes=processes,
            stop=stop,
            for_execution=for_execution,
            store=store,
            checkpointer=checkpointer,
            manager=manager,
            cache_policy=cache_policy,
            retry_policy=retry_policy,
            parent_ns=parent_ns,
            task_id_func=task_id_func,
        )
    if task_path[0] != PULL:
        return None

    # (PULL, node name)
    name = cast(str, task_path[1])
    if name not in processes:
        return None
    proc = processes[name]
    if checkpoint_null_version is None:
        return None
    # Only proceed if any of the channels read by this process were updated
    if not _triggers(
        channels,
        checkpoint["channel_versions"],
        checkpoint["versions_seen"].get(name),
        checkpoint_null_version,
        proc,
    ):
        return None

    triggers = tuple(sorted(proc.triggers))
    # create task id
    checkpoint_ns = f"{parent_ns}{NS_SEP}{name}" if parent_ns else name
    task_id = task_id_func(
        checkpoint_id_bytes,
        checkpoint_ns,
        str(step),
        name,
        PULL,
        *triggers,
    )
    task_checkpoint_ns = f"{checkpoint_ns}{NS_END}{task_id}"
    # create scratchpad; read through `configurable` so that a config without
    # a `configurable` section behaves like an empty one instead of raising
    scratchpad = _scratchpad(
        configurable.get(CONFIG_KEY_SCRATCHPAD),
        pending_writes,
        task_id,
        xxh3_128_hexdigest(task_checkpoint_ns.encode()),
        configurable.get(CONFIG_KEY_RESUME_MAP),
        step,
        stop,
    )
    # create task input
    try:
        val = _proc_input(
            proc,
            managed,
            channels,
            for_execution=for_execution,
            input_cache=input_cache,
            scratchpad=scratchpad,
        )
        if val is MISSING:
            return None
    except Exception as exc:
        if SUPPORTS_EXC_NOTES:
            exc.add_note(
                f"Before task with name '{name}' and path '{task_path[:3]}'"
            )
        raise

    metadata = {
        "langgraph_step": step,
        "langgraph_node": name,
        "langgraph_triggers": triggers,
        "langgraph_path": task_path[:3],
        "langgraph_checkpoint_ns": task_checkpoint_ns,
    }
    if task_id_checksum is not None:
        assert task_id == task_id_checksum, f"{task_id} != {task_id_checksum}"

    if not for_execution:
        return PregelTask(task_id, name, task_path[:3])

    node = proc.node
    if not node:
        # nothing to execute for this process
        return None

    if proc.metadata:
        metadata.update(proc.metadata)
    writes: deque[tuple[str, Any]] = deque()
    # node-level cache policy wins over the graph-level default
    cache_policy = proc.cache_policy or cache_policy
    if cache_policy:
        args_key = cache_policy.key_func(val)
        cache_key = CacheKey(
            (
                CACHE_NS_WRITES,
                (identifier(proc) or "__dynamic__"),
                name,
            ),
            xxh3_128_hexdigest(
                args_key.encode() if isinstance(args_key, str) else args_key,
            ),
            cache_policy.ttl,
        )
    else:
        cache_key = None
    runtime = cast(Runtime, configurable.get(CONFIG_KEY_RUNTIME, DEFAULT_RUNTIME))
    runtime = runtime.override(
        previous=checkpoint["channel_values"].get(PREVIOUS, None),
        store=store,
    )
    additional_config = {
        "metadata": metadata,
        "tags": proc.tags,
    }
    return PregelExecutableTask(
        name,
        val,
        node,
        writes,
        patch_config(
            merge_configs(config, cast(RunnableConfig, additional_config)),
            run_name=name,
            callbacks=(manager.get_child(f"graph:step:{step}") if manager else None),
            configurable={
                CONFIG_KEY_TASK_ID: task_id,
                # deque.extend is thread-safe
                CONFIG_KEY_SEND: writes.extend,
                CONFIG_KEY_READ: partial(
                    local_read,
                    scratchpad,
                    channels,
                    managed,
                    PregelTaskWrites(
                        task_path[:3],
                        name,
                        writes,
                        triggers,
                    ),
                ),
                CONFIG_KEY_CHECKPOINTER: (
                    checkpointer or configurable.get(CONFIG_KEY_CHECKPOINTER)
                ),
                CONFIG_KEY_CHECKPOINT_MAP: {
                    **configurable.get(CONFIG_KEY_CHECKPOINT_MAP, {}),
                    parent_ns: checkpoint["id"],
                },
                CONFIG_KEY_CHECKPOINT_ID: None,
                CONFIG_KEY_CHECKPOINT_NS: task_checkpoint_ns,
                CONFIG_KEY_SCRATCHPAD: scratchpad,
                CONFIG_KEY_RUNTIME: runtime,
            },
        ),
        triggers,
        proc.retry_policy or retry_policy,
        cache_key,
        task_id,
        task_path[:3],
        writers=proc.flat_writers,
        subgraphs=proc.subgraphs,
    )
def prepare_push_task_functional(
    task_path: tuple[str, tuple, int, str, Call],
    # (PUSH, parent task path, idx of PUSH write, id of parent task, Call)
    task_id_checksum: str | None,
    *,
    checkpoint: Checkpoint,
    checkpoint_id_bytes: bytes,
    pending_writes: list[PendingWrite],
    channels: Mapping[str, BaseChannel],
    managed: ManagedValueMapping,
    config: RunnableConfig,
    step: int,
    stop: int,
    for_execution: bool,
    store: BaseStore | None = None,
    checkpointer: BaseCheckpointSaver | None = None,
    manager: None | ParentRunManager | AsyncParentRunManager = None,
    cache_policy: CachePolicy | None = None,
    retry_policy: Sequence[RetryPolicy] = (),
    parent_ns: str,
    # namespace: bytes, *parts: str | bytes
    task_id_func: _TaskIDFn,
) -> PregelTask | PregelExecutableTask:
    """Prepare a PUSH task whose path ends in a `Call` (functional API).

    The task id is derived from the checkpoint id, the namespace, the step,
    the callee name, the parent task path and the index of the PUSH write.

    Returns:
        A PregelExecutableTask when `for_execution` is true, otherwise a
        PregelTask.

    Raises:
        ValueError: If the runnable built for the called function has no name.
        AssertionError: If `task_id_checksum` is given and does not match.
    """
    configurable = config.get(CONF, {})
    call = task_path[-1]
    runnable = get_runnable_for_task(call.func)
    name = runnable.name
    if name is None:
        raise ValueError("`call` functions must have a `__name__` attribute")

    # create task id
    triggers: Sequence[str] = PUSH_TRIGGER
    checkpoint_ns = name if not parent_ns else f"{parent_ns}{NS_SEP}{name}"
    task_id = task_id_func(
        checkpoint_id_bytes,
        checkpoint_ns,
        str(step),
        name,
        PUSH,
        task_path_str(task_path[1]),
        str(task_path[2]),
    )
    task_checkpoint_ns = f"{checkpoint_ns}:{task_id}"
    # True is appended to the task path to mark that a call is being made, so
    # interrupts must not be returned from this task (the parent is responsible)
    in_progress_task_path = (*task_path[:3], True)
    metadata = {
        "langgraph_step": step,
        "langgraph_node": name,
        "langgraph_triggers": triggers,
        "langgraph_path": in_progress_task_path,
        "langgraph_checkpoint_ns": task_checkpoint_ns,
    }
    if task_id_checksum is not None:
        assert task_id == task_id_checksum, f"{task_id} != {task_id_checksum}"

    if not for_execution:
        return PregelTask(task_id, name, in_progress_task_path)

    writes: deque[tuple[str, Any]] = deque()
    # a policy attached to the call wins over the default one
    cache_policy = call.cache_policy or cache_policy
    cache_key: CacheKey | None = None
    if cache_policy:
        positional, keyword = call.input[0], call.input[1]
        args_key = cache_policy.key_func(*positional, **keyword)
        cache_key = CacheKey(
            (
                CACHE_NS_WRITES,
                (identifier(call.func) or "__dynamic__"),
            ),
            xxh3_128_hexdigest(
                args_key.encode() if isinstance(args_key, str) else args_key,
            ),
            cache_policy.ttl,
        )
    scratchpad = _scratchpad(
        configurable.get(CONFIG_KEY_SCRATCHPAD),
        pending_writes,
        task_id,
        xxh3_128_hexdigest(task_checkpoint_ns.encode()),
        configurable.get(CONFIG_KEY_RESUME_MAP),
        step,
        stop,
    )
    runtime = cast(Runtime, configurable.get(CONFIG_KEY_RUNTIME, DEFAULT_RUNTIME))
    runtime = runtime.override(store=store)
    return PregelExecutableTask(
        name,
        call.input,
        runnable,
        writes,
        patch_config(
            merge_configs(config, {"metadata": metadata}),
            run_name=name,
            callbacks=call.callbacks
            or (manager.get_child(f"graph:step:{step}") if manager else None),
            configurable={
                CONFIG_KEY_TASK_ID: task_id,
                # deque.extend is thread-safe
                CONFIG_KEY_SEND: writes.extend,
                CONFIG_KEY_READ: partial(
                    local_read,
                    scratchpad,
                    channels,
                    managed,
                    PregelTaskWrites(in_progress_task_path, name, writes, triggers),
                ),
                CONFIG_KEY_CHECKPOINTER: (
                    checkpointer or configurable.get(CONFIG_KEY_CHECKPOINTER)
                ),
                CONFIG_KEY_CHECKPOINT_MAP: {
                    **configurable.get(CONFIG_KEY_CHECKPOINT_MAP, {}),
                    parent_ns: checkpoint["id"],
                },
                CONFIG_KEY_CHECKPOINT_ID: None,
                CONFIG_KEY_CHECKPOINT_NS: task_checkpoint_ns,
                CONFIG_KEY_SCRATCHPAD: scratchpad,
                CONFIG_KEY_RUNTIME: runtime,
            },
        ),
        triggers,
        call.retry_policy or retry_policy,
        cache_key,
        task_id,
        in_progress_task_path,
    )
def prepare_push_task_send(
    task_path: tuple[str, tuple],
    # (PUSH, parent task path)
    task_id_checksum: str | None,
    *,
    checkpoint: Checkpoint,
    checkpoint_id_bytes: bytes,
    pending_writes: list[PendingWrite],
    channels: Mapping[str, BaseChannel],
    managed: ManagedValueMapping,
    config: RunnableConfig,
    step: int,
    stop: int,
    for_execution: bool,
    store: BaseStore | None = None,
    checkpointer: BaseCheckpointSaver | None = None,
    manager: None | ParentRunManager | AsyncParentRunManager = None,
    cache_policy: CachePolicy | None = None,
    retry_policy: Sequence[RetryPolicy] = (),
    parent_ns: str,
    task_id_func: _TaskIDFn,
    processes: Mapping[str, PregelNode],
) -> PregelTask | PregelExecutableTask | None:
    """Prepare the PUSH task for one pending `Send` in the TASKS channel.

    Args:
        task_path: `(PUSH, idx)` where `idx` is the position of the Send in
            the TASKS channel. Any other length is rejected with a warning.
        task_id_checksum: When not None, the computed task id must equal it.
        checkpoint: The current checkpoint.
        checkpoint_id_bytes: The checkpoint id as raw bytes, used as the
            namespace for task id generation.
        pending_writes: The list of pending writes.
        channels: The mapping of channel names to BaseChannel instances.
        managed: The mapping of managed value names to functions.
        config: The `Runnable` configuration. A missing `configurable` section
            is treated as empty.
        step: The current step.
        stop: Forwarded to the task scratchpad.
        for_execution: When true return a PregelExecutableTask, otherwise a
            lightweight PregelTask.
        store: Store made available to the task through its runtime.
        checkpointer: Checkpointer exposed to the task; falls back to the one
            found in `config`.
        manager: Parent run manager used to derive the task callbacks.
        cache_policy: Cache policy used when the node defines none.
        retry_policy: Retry policies used when the node defines none.
        parent_ns: Checkpoint namespace of the parent ("" at the root).
        task_id_func: Function used to derive the task id.
        processes: The mapping of process names to PregelNode instances.

    Returns:
        The prepared task, or None when the path is invalid, the TASKS channel
        is unavailable, the index is out of range, the packet is not a `Send`,
        or the target node is unknown or has nothing to run.

    Raises:
        AssertionError: If `task_id_checksum` is given and does not match.
    """
    if len(task_path) != 2:
        logger.warning(f"Ignoring invalid PUSH task path {task_path}")
        return None

    # SEND tasks, executed in superstep n+1
    # (PUSH, idx of pending send)
    idx = cast(int, task_path[1])
    if not channels[TASKS].is_available():
        return None
    sends: Sequence[Send] = channels[TASKS].get()
    if idx < 0 or idx >= len(sends):
        return None
    packet = sends[idx]
    if not isinstance(packet, Send):
        logger.warning(
            f"Ignoring invalid packet type {type(packet)} in pending sends"
        )
        return None
    if packet.node not in processes:
        logger.warning(f"Ignoring unknown node name {packet.node} in pending sends")
        return None
    # find process
    proc = processes[packet.node]
    proc_node = proc.node
    if proc_node is None:
        return None
    # create task id
    triggers = PUSH_TRIGGER
    checkpoint_ns = (
        f"{parent_ns}{NS_SEP}{packet.node}" if parent_ns else packet.node
    )
    task_id = task_id_func(
        checkpoint_id_bytes,
        checkpoint_ns,
        str(step),
        packet.node,
        PUSH,
        str(idx),
    )

    configurable = config.get(CONF, {})
    task_checkpoint_ns = f"{checkpoint_ns}:{task_id}"
    # we append False to the task path to indicate that a call is not being made
    # so we should return interrupts from this task
    translated_task_path = (*task_path[:3], False)
    metadata = {
        "langgraph_step": step,
        "langgraph_node": packet.node,
        "langgraph_triggers": triggers,
        "langgraph_path": translated_task_path,
        "langgraph_checkpoint_ns": task_checkpoint_ns,
    }
    if task_id_checksum is not None:
        assert task_id == task_id_checksum, f"{task_id} != {task_id_checksum}"

    if not for_execution:
        return PregelTask(task_id, packet.node, translated_task_path)

    if proc.metadata:
        metadata.update(proc.metadata)
    writes: deque[tuple[str, Any]] = deque()
    # node-level cache policy wins over the graph-level default
    cache_policy = proc.cache_policy or cache_policy
    if cache_policy:
        args_key = cache_policy.key_func(packet.arg)
        cache_key = CacheKey(
            (
                CACHE_NS_WRITES,
                (identifier(proc) or "__dynamic__"),
                packet.node,
            ),
            xxh3_128_hexdigest(
                args_key.encode() if isinstance(args_key, str) else args_key,
            ),
            cache_policy.ttl,
        )
    else:
        cache_key = None
    # read through `configurable` so that a config without a `configurable`
    # section behaves like an empty one instead of raising KeyError
    scratchpad = _scratchpad(
        configurable.get(CONFIG_KEY_SCRATCHPAD),
        pending_writes,
        task_id,
        xxh3_128_hexdigest(task_checkpoint_ns.encode()),
        configurable.get(CONFIG_KEY_RESUME_MAP),
        step,
        stop,
    )
    runtime = cast(Runtime, configurable.get(CONFIG_KEY_RUNTIME, DEFAULT_RUNTIME))
    runtime = runtime.override(
        store=store, previous=checkpoint["channel_values"].get(PREVIOUS, None)
    )
    additional_config: RunnableConfig = {
        "metadata": metadata,
        "tags": proc.tags,
    }
    return PregelExecutableTask(
        packet.node,
        packet.arg,
        proc_node,
        writes,
        patch_config(
            merge_configs(config, additional_config),
            run_name=packet.node,
            callbacks=(
                manager.get_child(f"graph:step:{step}") if manager else None
            ),
            configurable={
                CONFIG_KEY_TASK_ID: task_id,
                # deque.extend is thread-safe
                CONFIG_KEY_SEND: writes.extend,
                CONFIG_KEY_READ: partial(
                    local_read,
                    scratchpad,
                    channels,
                    managed,
                    PregelTaskWrites(
                        translated_task_path, packet.node, writes, triggers
                    ),
                ),
                CONFIG_KEY_CHECKPOINTER: (
                    checkpointer or configurable.get(CONFIG_KEY_CHECKPOINTER)
                ),
                CONFIG_KEY_CHECKPOINT_MAP: {
                    **configurable.get(CONFIG_KEY_CHECKPOINT_MAP, {}),
                    parent_ns: checkpoint["id"],
                },
                CONFIG_KEY_CHECKPOINT_ID: None,
                CONFIG_KEY_CHECKPOINT_NS: task_checkpoint_ns,
                CONFIG_KEY_SCRATCHPAD: scratchpad,
                CONFIG_KEY_RUNTIME: runtime,
            },
        ),
        triggers,
        proc.retry_policy or retry_policy,
        cache_key,
        task_id,
        translated_task_path,
        writers=proc.flat_writers,
        subgraphs=proc.subgraphs,
    )
def checkpoint_null_version(
    checkpoint: Checkpoint,
) -> V | None:
    """Return the "null" (default-constructed) version for a checkpoint.

    The version type is taken from the first value found in
    ``checkpoint["channel_versions"]`` and instantiated with no arguments
    (for example ``0`` for ``int`` versions or ``""`` for ``str`` versions).

    Returns:
        A default instance of the version type, or ``None`` when the
        checkpoint holds no channel versions to infer the type from.
    """
    recorded_versions = iter(checkpoint["channel_versions"].values())
    try:
        sample = next(recorded_versions)
    except StopIteration:
        # No versions recorded, so there is no type to instantiate.
        return None
    return type(sample)()
- def _triggers(
- channels: Mapping[str, BaseChannel],
- versions: ChannelVersions,
- seen: ChannelVersions | None,
- null_version: V,
- proc: PregelNode,
- ) -> bool:
- if seen is None:
- for chan in proc.triggers:
- if channels[chan].is_available():
- return True
- else:
- for chan in proc.triggers:
- if channels[chan].is_available() and versions.get( # type: ignore[operator]
- chan, null_version
- ) > seen.get(chan, null_version):
- return True
- return False
def _scratchpad(
    parent_scratchpad: PregelScratchpad | None,
    pending_writes: list[PendingWrite],
    task_id: str,
    namespace_hash: str,
    resume_map: dict[str, Any] | None,
    step: int,
    stop: int,
) -> PregelScratchpad:
    """Build the scratchpad for one task from the pending writes.

    Args:
        parent_scratchpad: Scratchpad consulted for a global resume value
            when none is found in ``pending_writes``.
        pending_writes: Writes indexed as ``(task_id, channel, value)``.
        task_id: Id of the task the scratchpad is built for.
        namespace_hash: Key looked up in ``resume_map``.
        resume_map: Optional mapping of namespace hash to a resume value.
        step: Stored on the scratchpad as ``step``.
        stop: Stored on the scratchpad as ``stop``.

    Returns:
        A ``PregelScratchpad`` with fresh call/interrupt/subgraph counters,
        the task's resume values and a ``get_null_resume`` accessor.
    """
    # None cannot be used as a resume value, because it would be difficult to
    # distinguish from missing when used over http
    null_resume_write = None
    task_resume_write = []

    if len(pending_writes) > 0:
        # global resume value: first RESUME write addressed to the null task
        null_resume_write = next(
            (
                write
                for write in pending_writes
                if write[0] == NULL_TASK_ID and write[1] == RESUME
            ),
            None,
        )

        # task-specific resume value: first RESUME write addressed to this task
        own_write = next(
            (
                write
                for write in pending_writes
                if write[0] == task_id and write[1] == RESUME
            ),
            None,
        )
        if own_write is not None:
            resume_value = own_write[2]
            # NOTE(review): a list value is reused as-is (not copied), so the
            # append below mutates the list held by the pending write.
            task_resume_write = (
                resume_value if isinstance(resume_value, list) else [resume_value]
            )

        # namespace-specific resume value, only consulted when there are
        # pending writes
        if resume_map and namespace_hash in resume_map:
            task_resume_write.append(resume_map[namespace_hash])

    def get_null_resume(consume: bool = False) -> Any:
        """Return the global resume value, optionally consuming it."""
        if null_resume_write is not None:
            if consume:
                try:
                    pending_writes.remove(null_resume_write)
                except ValueError:
                    # The write is no longer in pending_writes (already removed).
                    return None
            return null_resume_write[2]
        if parent_scratchpad is None:
            return None
        return parent_scratchpad.get_null_resume(consume)

    # each counter is a LazyAtomicCounter (+= 1 is not thread-safe)
    return PregelScratchpad(
        step=step,
        stop=stop,
        # call
        call_counter=LazyAtomicCounter(),
        # interrupt
        interrupt_counter=LazyAtomicCounter(),
        resume=task_resume_write,
        get_null_resume=get_null_resume,
        # subgraph
        subgraph_counter=LazyAtomicCounter(),
    )
- def _proc_input(
- proc: PregelNode,
- managed: ManagedValueMapping,
- channels: Mapping[str, BaseChannel],
- *,
- for_execution: bool,
- scratchpad: PregelScratchpad,
- input_cache: dict[INPUT_CACHE_KEY_TYPE, Any] | None,
- ) -> Any:
- """Prepare input for a PULL task, based on the process's channels and triggers."""
- # if in cache return shallow copy
- if input_cache is not None and proc.input_cache_key in input_cache:
- return copy(input_cache[proc.input_cache_key])
- # If all trigger channels subscribed by this process are not empty
- # then invoke the process with the values of all non-empty channels
- if isinstance(proc.channels, list):
- val: dict[str, Any] = {}
- for chan in proc.channels:
- if chan in channels:
- if channels[chan].is_available():
- val[chan] = channels[chan].get()
- else:
- val[chan] = managed[chan].get(scratchpad)
- elif isinstance(proc.channels, str):
- if proc.channels in channels:
- if channels[proc.channels].is_available():
- val = channels[proc.channels].get()
- else:
- return MISSING
- else:
- return MISSING
- else:
- raise RuntimeError(
- f"Invalid channels type, expected list or dict, got {proc.channels}"
- )
- # If the process has a mapper, apply it to the value
- if for_execution and proc.mapper is not None:
- val = proc.mapper(val)
- # Cache the input value
- if input_cache is not None:
- input_cache[proc.input_cache_key] = val
- return val
- def _uuid5_str(namespace: bytes, *parts: str | bytes) -> str:
- """Generate a UUID from the SHA-1 hash of a namespace and str parts."""
- sha = sha1(namespace, usedforsecurity=False)
- sha.update(b"".join(p.encode() if isinstance(p, str) else p for p in parts))
- hex = sha.hexdigest()
- return f"{hex[:8]}-{hex[8:12]}-{hex[12:16]}-{hex[16:20]}-{hex[20:32]}"
def _xxhash_str(namespace: bytes, *parts: str | bytes) -> str:
    """Build a UUID-shaped string from the XXH3 hash of a namespace and parts.

    ``str`` parts are encoded with ``str.encode()``; the namespace and all
    parts are concatenated in order and hashed in one call. The first 32 hex
    digits of the digest are laid out as 8-4-4-4-12.
    """
    encoded_parts = [
        part.encode() if isinstance(part, str) else part for part in parts
    ]
    digest = xxh3_128_hexdigest(namespace + b"".join(encoded_parts))
    return "-".join(
        (digest[:8], digest[8:12], digest[12:16], digest[16:20], digest[20:32])
    )
def task_path_str(tup: str | int | tuple) -> str:
    """Render a task path (or one of its elements) as a string.

    * tuples and lists become ``~`` followed by their elements, rendered
      recursively and joined with ``", "``;
    * integers (including ``bool``) are zero-padded to a width of 10;
    * anything else is passed through ``str``.
    """
    if isinstance(tup, (tuple, list)):
        rendered = ", ".join(map(task_path_str, tup))
        return f"~{rendered}"
    if isinstance(tup, int):
        return format(tup, "010d")
    return str(tup)
# One module-wide lock, shared by all counters; it only guards the lazy
# creation of a counter's underlying iterator.
LAZY_ATOMIC_COUNTER_LOCK = threading.Lock()


class LazyAtomicCounter:
    """Callable counter yielding 0, 1, 2, ... that is set up on first use.

    The underlying ``itertools.count`` is created only on the first call,
    using double-checked locking on ``LAZY_ATOMIC_COUNTER_LOCK`` so that
    only one iterator is ever installed per instance.
    """

    __slots__ = ("_counter",)

    # Bound ``__next__`` of the underlying count, or None until first call.
    _counter: Callable[[], int] | None

    def __init__(self) -> None:
        self._counter = None

    def __call__(self) -> int:
        """Return the next integer, starting at 0."""
        advance = self._counter
        if advance is None:
            with LAZY_ATOMIC_COUNTER_LOCK:
                # Re-check under the lock: the iterator may have been
                # installed while this call was waiting for it.
                advance = self._counter
                if advance is None:
                    advance = self._counter = itertools.count().__next__
        return advance()
def sanitize_untracked_values_in_send(
    packet: Send, channels: Mapping[str, BaseChannel]
) -> Send:
    """Drop UntrackedValue entries from Send.arg for safe checkpointing.

    Send is often called with state to be passed to the dest node, which may
    contain UntrackedValues at the top level. Send is not typed and arg may be
    a nested dict; only the top-level keys are checked here.

    Returns:
        ``packet`` itself when its arg is not a dict, otherwise a new ``Send``
        for the same node whose arg omits every key that names an
        ``UntrackedValue`` channel.
    """
    payload = packet.arg
    if not isinstance(payload, dict):
        # Command
        return packet

    # top level keys should be the channel names
    tracked = {}
    for key, value in payload.items():
        if isinstance(channels.get(key), UntrackedValue):
            continue
        tracked[key] = value
    return Send(node=packet.node, arg=tracked)
|