| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308 |
- from __future__ import annotations
- from collections.abc import Iterable, Iterator, Mapping, Sequence
- from dataclasses import asdict
- from typing import Any
- from uuid import UUID
- from langchain_core.runnables import RunnableConfig
- from langgraph.checkpoint.base import CheckpointMetadata, PendingWrite
- from typing_extensions import TypedDict
- from langgraph._internal._config import patch_checkpoint_map
- from langgraph._internal._constants import (
- CONF,
- CONFIG_KEY_CHECKPOINT_NS,
- ERROR,
- INTERRUPT,
- NS_END,
- NS_SEP,
- RETURN,
- )
- from langgraph._internal._typing import MISSING
- from langgraph.channels.base import BaseChannel
- from langgraph.constants import TAG_HIDDEN
- from langgraph.pregel._io import read_channels
- from langgraph.types import PregelExecutableTask, PregelTask, StateSnapshot
# Public API of this module: the debug-stream payload TypedDicts.
__all__ = ("TaskPayload", "TaskResultPayload", "CheckpointTask", "CheckpointPayload")
class TaskPayload(TypedDict):
    """Payload of a "task" debug event, yielded by ``map_debug_tasks``."""

    id: str  # task id
    name: str  # node name the task runs
    input: Any  # input value passed to the task
    triggers: list[str]  # channel names that triggered this task
class TaskResultPayload(TypedDict):
    """Payload of a "task_result" debug event, yielded by ``map_debug_task_results``."""

    id: str  # task id
    name: str  # node name the task ran
    error: str | None  # first ERROR write for the task, if any
    interrupts: list[dict]  # interrupts raised by the task, serialized via asdict
    result: dict[str, Any]  # channel -> write value (see map_task_result_writes)
class CheckpointTask(TypedDict):
    """Task entry inside a CheckpointPayload."""

    id: str  # task id
    name: str  # node name
    error: str | None  # task error, if it failed
    interrupts: list[dict]  # interrupts raised by the task, serialized via asdict
    state: StateSnapshot | RunnableConfig | None  # subgraph state or config signal, if any
class CheckpointPayload(TypedDict):
    """Payload of a "checkpoint" debug event, yielded by ``map_debug_checkpoint``."""

    config: RunnableConfig | None  # checkpoint config, with __pregel_ keys removed
    metadata: CheckpointMetadata  # checkpoint metadata
    values: dict[str, Any]  # current values of the stream channels
    next: list[str]  # names of the tasks scheduled to run next
    parent_config: RunnableConfig | None  # parent checkpoint config, if any
    tasks: list[CheckpointTask]  # tasks at this checkpoint, with writes applied
# Fixed UUID namespace — presumably for uuid5-based deterministic task ids;
# not referenced elsewhere in this chunk. TODO(review): confirm usage at call sites.
TASK_NAMESPACE = UUID("6ba7b831-9dad-11d1-80b4-00c04fd430c8")
def map_debug_tasks(tasks: Iterable[PregelExecutableTask]) -> Iterator[TaskPayload]:
    """Produce "task" events for stream_mode=debug."""
    for t in tasks:
        # tasks tagged as hidden are excluded from the debug stream
        is_hidden = t.config is not None and TAG_HIDDEN in t.config.get("tags", [])
        if is_hidden:
            continue
        payload: TaskPayload = {
            "id": t.id,
            "name": t.name,
            "input": t.input,
            "triggers": t.triggers,
        }
        yield payload
def is_multiple_channel_write(value: Any) -> bool:
    """Return True if the payload already wraps multiple writes from the same channel."""
    if not isinstance(value, dict):
        return False
    # an aggregated payload stores its writes under the "$writes" key as a list
    writes = value.get("$writes")
    return isinstance(writes, list)
def map_task_result_writes(writes: Sequence[tuple[str, Any]]) -> dict[str, Any]:
    """Fold task writes into a result dict, aggregating repeated writes per channel.

    If a channel holds a single write, it is recorded as ``{channel: write}``.
    If a channel holds multiple writes, they are recorded as
    ``{channel: {'$writes': [write1, write2, ...]}}``.
    """
    result: dict[str, Any] = {}
    for channel, value in writes:
        # Membership test rather than a None check: a first write whose value
        # is None must still be aggregated on a repeat write, not overwritten.
        if channel in result:
            existing = result[channel]
            channel_writes = (
                existing["$writes"]
                if is_multiple_channel_write(existing)
                else [existing]
            )
            channel_writes.append(value)
            result[channel] = {"$writes": channel_writes}
        else:
            result[channel] = value
    return result
def map_debug_task_results(
    task_tup: tuple[PregelExecutableTask, Sequence[tuple[str, Any]]],
    stream_keys: str | Sequence[str],
) -> Iterator[TaskResultPayload]:
    """Produce "task_result" events for stream_mode=debug."""
    task, writes = task_tup
    if isinstance(stream_keys, str):
        stream_channels_list: Sequence[str] = [stream_keys]
    else:
        stream_channels_list = stream_keys
    # single pass over the writes, splitting them into error / result / interrupts
    error: Any = None
    error_seen = False
    result_writes: list[tuple[str, Any]] = []
    interrupts: list[dict] = []
    for chan, val in writes:
        if chan == ERROR and not error_seen:
            # keep only the first ERROR write, like next(...) would
            error = val
            error_seen = True
        if chan in stream_channels_list or chan == RETURN:
            result_writes.append((chan, val))
        if chan == INTERRUPT:
            vals = val if isinstance(val, Sequence) else [val]
            interrupts.extend(asdict(v) for v in vals)
    yield {
        "id": task.id,
        "name": task.name,
        "error": error,
        "result": map_task_result_writes(result_writes),
        "interrupts": interrupts,
    }
def rm_pregel_keys(config: RunnableConfig | None) -> RunnableConfig | None:
    """Remove pregel-specific keys from the config.

    Drops every ``__pregel_``-prefixed key from the "configurable" mapping.
    Note that only the "configurable" section is carried over to the result.
    """
    if config is None:
        return config
    configurable = config.get("configurable", {})
    cleaned: dict = {}
    for key, value in configurable.items():
        if key.startswith("__pregel_"):
            continue
        cleaned[key] = value
    return {"configurable": cleaned}
def map_debug_checkpoint(
    config: RunnableConfig,
    channels: Mapping[str, BaseChannel],
    stream_channels: str | Sequence[str],
    metadata: CheckpointMetadata,
    tasks: Iterable[PregelExecutableTask],
    pending_writes: list[PendingWrite],
    parent_config: RunnableConfig | None,
    output_keys: str | Sequence[str],
) -> Iterator[CheckpointPayload]:
    """Produce "checkpoint" events for stream_mode=debug.

    Yields a single CheckpointPayload describing the checkpoint config,
    channel values, metadata, and the tasks (with their writes applied).
    """
    # Materialize `tasks` once: it is iterated three times below (subgraph
    # scan, "next" names, tasks_w_writes). A one-shot iterator would silently
    # yield empty results for the later passes.
    tasks = list(tasks)
    parent_ns = config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")
    task_states: dict[str, RunnableConfig | StateSnapshot] = {}
    for task in tasks:
        if not task.subgraphs:
            continue
        # assemble checkpoint_ns for this task
        task_ns = f"{task.name}{NS_END}{task.id}"
        if parent_ns:
            task_ns = f"{parent_ns}{NS_SEP}{task_ns}"
        # set config as signal that subgraph checkpoints exist
        task_states[task.id] = {
            CONF: {
                "thread_id": config[CONF]["thread_id"],
                CONFIG_KEY_CHECKPOINT_NS: task_ns,
            }
        }
    yield {
        "config": rm_pregel_keys(patch_checkpoint_map(config, metadata)),
        "parent_config": rm_pregel_keys(patch_checkpoint_map(parent_config, metadata)),
        "values": read_channels(channels, stream_channels),
        "metadata": metadata,
        "next": [t.name for t in tasks],
        "tasks": [
            # error tasks carry no result; finished tasks carry result +
            # interrupts; pending tasks carry interrupts only
            {
                "id": t.id,
                "name": t.name,
                "error": t.error,
                "state": t.state,
            }
            if t.error
            else {
                "id": t.id,
                "name": t.name,
                "result": t.result,
                "interrupts": tuple(asdict(i) for i in t.interrupts),
                "state": t.state,
            }
            if t.result
            else {
                "id": t.id,
                "name": t.name,
                "interrupts": tuple(asdict(i) for i in t.interrupts),
                "state": t.state,
            }
            for t in tasks_w_writes(tasks, pending_writes, task_states, output_keys)
        ],
    }
def tasks_w_writes(
    tasks: Iterable[PregelTask | PregelExecutableTask],
    pending_writes: list[PendingWrite] | None,
    states: dict[str, RunnableConfig | StateSnapshot] | None,
    output_keys: str | Sequence[str],
) -> tuple[PregelTask, ...]:
    """Apply writes / subgraph states to tasks to be returned in a StateSnapshot.

    Args:
        tasks: tasks to decorate with results, errors and interrupts.
        pending_writes: recorded ``(task_id, channel, value)`` writes.
        states: optional per-task subgraph state/config, keyed by task id.
        output_keys: channel name(s) whose writes form a task's result.

    Returns:
        One PregelTask per input task, in input order.
    """
    pending_writes = pending_writes or []
    out: list[PregelTask] = []
    for task in tasks:
        # an explicit RETURN write takes precedence as the task result
        rtn = next(
            (
                val
                for tid, chan, val in pending_writes
                if tid == task.id and chan == RETURN
            ),
            MISSING,
        )
        task_error = next(
            (exc for tid, n, exc in pending_writes if tid == task.id and n == ERROR),
            None,
        )
        task_interrupts = tuple(
            v
            for tid, n, vv in pending_writes
            if tid == task.id and n == INTERRUPT
            for v in (vv if isinstance(vv, Sequence) else [vv])
        )
        task_writes = [
            (chan, val)
            for tid, chan, val in pending_writes
            if tid == task.id and chan not in (ERROR, INTERRUPT, RETURN)
        ]
        if rtn is not MISSING:
            task_result = rtn
        elif isinstance(output_keys, str):
            # single output channel: unwrap to just the write value
            filtered_writes = [
                (chan, val) for chan, val in task_writes if chan == output_keys
            ]
            mapped_writes = map_task_result_writes(filtered_writes)
            task_result = mapped_writes.get(output_keys) if mapped_writes else None
        else:
            # multiple output channels: keep a channel -> value mapping;
            # repeated writes to the same channel are aggregated into
            # {'$writes': [write1, write2, ...]}
            # (the unreachable re-check for str output_keys was removed here:
            # the elif above already handles that case)
            filtered_writes = [
                (chan, val) for chan, val in task_writes if chan in output_keys
            ]
            mapped_writes = map_task_result_writes(filtered_writes)
            task_result = mapped_writes if filtered_writes else {}
        # attach a result only if the task actually wrote something
        has_writes = rtn is not MISSING or any(
            w[0] == task.id and w[1] not in (ERROR, INTERRUPT) for w in pending_writes
        )
        out.append(
            PregelTask(
                task.id,
                task.name,
                task.path,
                task_error,
                task_interrupts,
                states.get(task.id) if states else None,
                task_result if has_writes else None,
            )
        )
    return tuple(out)
# ANSI SGR parameter strings ("<style>;<fg-color>") keyed by color name;
# "0;" is the normal style, "1;" (gray) the bright/bold variant.
COLOR_MAPPING = {
    "black": "0;30",
    "red": "0;31",
    "green": "0;32",
    "yellow": "0;33",
    "blue": "0;34",
    "magenta": "0;35",
    "cyan": "0;36",
    "white": "0;37",
    "gray": "1;30",
}
def get_colored_text(text: str, color: str) -> str:
    """Wrap ``text`` in ANSI escape codes for the named ``color``.

    Raises:
        KeyError: if ``color`` is not a key of COLOR_MAPPING.
    """
    # COLOR_MAPPING values are complete SGR parameter strings ("0;31" etc.),
    # so they must be emitted as-is. The previous hard-coded "1;3" prefix
    # concatenated into malformed sequences such as "\033[1;30;31m"
    # (bold + black + red) instead of the intended "\033[0;31m".
    return f"\033[{COLOR_MAPPING[color]}m{text}\033[0m"
def get_bolded_text(text: str) -> str:
    """Wrap ``text`` in ANSI bold escape codes."""
    bold_on = "\033[1m"
    reset = "\033[0m"
    return bold_on + text + reset
|