from __future__ import annotations from collections.abc import Iterable, Iterator, Mapping, Sequence from dataclasses import asdict from typing import Any from uuid import UUID from langchain_core.runnables import RunnableConfig from langgraph.checkpoint.base import CheckpointMetadata, PendingWrite from typing_extensions import TypedDict from langgraph._internal._config import patch_checkpoint_map from langgraph._internal._constants import ( CONF, CONFIG_KEY_CHECKPOINT_NS, ERROR, INTERRUPT, NS_END, NS_SEP, RETURN, ) from langgraph._internal._typing import MISSING from langgraph.channels.base import BaseChannel from langgraph.constants import TAG_HIDDEN from langgraph.pregel._io import read_channels from langgraph.types import PregelExecutableTask, PregelTask, StateSnapshot __all__ = ("TaskPayload", "TaskResultPayload", "CheckpointTask", "CheckpointPayload") class TaskPayload(TypedDict): id: str name: str input: Any triggers: list[str] class TaskResultPayload(TypedDict): id: str name: str error: str | None interrupts: list[dict] result: dict[str, Any] class CheckpointTask(TypedDict): id: str name: str error: str | None interrupts: list[dict] state: StateSnapshot | RunnableConfig | None class CheckpointPayload(TypedDict): config: RunnableConfig | None metadata: CheckpointMetadata values: dict[str, Any] next: list[str] parent_config: RunnableConfig | None tasks: list[CheckpointTask] TASK_NAMESPACE = UUID("6ba7b831-9dad-11d1-80b4-00c04fd430c8") def map_debug_tasks(tasks: Iterable[PregelExecutableTask]) -> Iterator[TaskPayload]: """Produce "task" events for stream_mode=debug.""" for task in tasks: if task.config is not None and TAG_HIDDEN in task.config.get("tags", []): continue yield { "id": task.id, "name": task.name, "input": task.input, "triggers": task.triggers, } def is_multiple_channel_write(value: Any) -> bool: """Return True if the payload already wraps multiple writes from the same channel.""" return ( isinstance(value, dict) and "$writes" in value and isinstance(value["$writes"], list) ) def map_task_result_writes(writes: Sequence[tuple[str, Any]]) -> dict[str, Any]: """Folds task writes into a result dict and aggregates multiple writes to the same channel. If the channel contains a single write, we record the write in the result dict as `{channel: write}` If the channel contains multiple writes, we record the writes in the result dict as `{channel: {'$writes': [write1, write2, ...]}}`""" result: dict[str, Any] = {} for channel, value in writes: existing = result.get(channel) if existing is not None: channel_writes = ( existing["$writes"] if is_multiple_channel_write(existing) else [existing] ) channel_writes.append(value) result[channel] = {"$writes": channel_writes} else: result[channel] = value return result def map_debug_task_results( task_tup: tuple[PregelExecutableTask, Sequence[tuple[str, Any]]], stream_keys: str | Sequence[str], ) -> Iterator[TaskResultPayload]: """Produce "task_result" events for stream_mode=debug.""" stream_channels_list = ( [stream_keys] if isinstance(stream_keys, str) else stream_keys ) task, writes = task_tup yield { "id": task.id, "name": task.name, "error": next((w[1] for w in writes if w[0] == ERROR), None), "result": map_task_result_writes( [w for w in writes if w[0] in stream_channels_list or w[0] == RETURN] ), "interrupts": [ asdict(v) for w in writes if w[0] == INTERRUPT for v in (w[1] if isinstance(w[1], Sequence) else [w[1]]) ], } def rm_pregel_keys(config: RunnableConfig | None) -> RunnableConfig | None: """Remove pregel-specific keys from the config.""" if config is None: return config return { "configurable": { k: v for k, v in config.get("configurable", {}).items() if not k.startswith("__pregel_") } } def map_debug_checkpoint( config: RunnableConfig, channels: Mapping[str, BaseChannel], stream_channels: str | Sequence[str], metadata: CheckpointMetadata, tasks: Iterable[PregelExecutableTask], pending_writes: list[PendingWrite], parent_config: RunnableConfig | None, output_keys: str | Sequence[str], ) -> Iterator[CheckpointPayload]: """Produce "checkpoint" events for stream_mode=debug.""" parent_ns = config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "") task_states: dict[str, RunnableConfig | StateSnapshot] = {} for task in tasks: if not task.subgraphs: continue # assemble checkpoint_ns for this task task_ns = f"{task.name}{NS_END}{task.id}" if parent_ns: task_ns = f"{parent_ns}{NS_SEP}{task_ns}" # set config as signal that subgraph checkpoints exist task_states[task.id] = { CONF: { "thread_id": config[CONF]["thread_id"], CONFIG_KEY_CHECKPOINT_NS: task_ns, } } yield { "config": rm_pregel_keys(patch_checkpoint_map(config, metadata)), "parent_config": rm_pregel_keys(patch_checkpoint_map(parent_config, metadata)), "values": read_channels(channels, stream_channels), "metadata": metadata, "next": [t.name for t in tasks], "tasks": [ { "id": t.id, "name": t.name, "error": t.error, "state": t.state, } if t.error else { "id": t.id, "name": t.name, "result": t.result, "interrupts": tuple(asdict(i) for i in t.interrupts), "state": t.state, } if t.result else { "id": t.id, "name": t.name, "interrupts": tuple(asdict(i) for i in t.interrupts), "state": t.state, } for t in tasks_w_writes(tasks, pending_writes, task_states, output_keys) ], } def tasks_w_writes( tasks: Iterable[PregelTask | PregelExecutableTask], pending_writes: list[PendingWrite] | None, states: dict[str, RunnableConfig | StateSnapshot] | None, output_keys: str | Sequence[str], ) -> tuple[PregelTask, ...]: """Apply writes / subgraph states to tasks to be returned in a StateSnapshot.""" pending_writes = pending_writes or [] out: list[PregelTask] = [] for task in tasks: rtn = next( ( val for tid, chan, val in pending_writes if tid == task.id and chan == RETURN ), MISSING, ) task_error = next( (exc for tid, n, exc in pending_writes if tid == task.id and n == ERROR), None, ) task_interrupts = tuple( v for tid, n, vv in pending_writes if tid == task.id and n == INTERRUPT for v in (vv if isinstance(vv, Sequence) else [vv]) ) task_writes = [ (chan, val) for tid, chan, val in pending_writes if tid == task.id and chan not in (ERROR, INTERRUPT, RETURN) ] if rtn is not MISSING: task_result = rtn elif isinstance(output_keys, str): # unwrap single channel writes to just the write value filtered_writes = [ (chan, val) for chan, val in task_writes if chan == output_keys ] mapped_writes = map_task_result_writes(filtered_writes) task_result = mapped_writes.get(str(output_keys)) if mapped_writes else None else: if isinstance(output_keys, str): output_keys = [output_keys] # map task result writes to the desired output channels # repeateed writes to the same channel are aggregated into: {'$writes': [write1, write2, ...]} filtered_writes = [ (chan, val) for chan, val in task_writes if chan in output_keys ] mapped_writes = map_task_result_writes(filtered_writes) task_result = mapped_writes if filtered_writes else {} has_writes = rtn is not MISSING or any( w[0] == task.id and w[1] not in (ERROR, INTERRUPT) for w in pending_writes ) out.append( PregelTask( task.id, task.name, task.path, task_error, task_interrupts, states.get(task.id) if states else None, task_result if has_writes else None, ) ) return tuple(out) COLOR_MAPPING = { "black": "0;30", "red": "0;31", "green": "0;32", "yellow": "0;33", "blue": "0;34", "magenta": "0;35", "cyan": "0;36", "white": "0;37", "gray": "1;30", } def get_colored_text(text: str, color: str) -> str: """Get colored text.""" return f"\033[1;3{COLOR_MAPPING[color]}m{text}\033[0m" def get_bolded_text(text: str) -> str: """Get bolded text.""" return f"\033[1m{text}\033[0m"