| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120 |
- from __future__ import annotations
- from collections.abc import Mapping, Sequence
- from typing import Any
- from langgraph._internal._constants import RESERVED
- from langgraph.channels.base import BaseChannel
- from langgraph.managed.base import ManagedValueMapping
- from langgraph.pregel._read import PregelNode
- from langgraph.types import All
def validate_graph(
    nodes: Mapping[str, PregelNode],
    channels: dict[str, BaseChannel],
    managed: ManagedValueMapping,
    input_channels: str | Sequence[str],
    output_channels: str | Sequence[str],
    stream_channels: str | Sequence[str] | None,
    interrupt_after_nodes: All | Sequence[str],
    interrupt_before_nodes: All | Sequence[str],
) -> None:
    """Validate the wiring between a graph's nodes and its channels.

    Checks run in the order below; the first failure raises:

    1. No key of ``channels`` or ``managed`` may appear in ``RESERVED``.
    2. For each node, in iteration order of ``nodes``: its name must not be
       in ``RESERVED``, it must be a ``PregelNode``, and every channel it
       reads must be known. A single-string ``node.channels`` must be a key
       of ``channels``; a collection of names may also use keys of
       ``managed``.
    3. Every trigger collected from the nodes must be a key of ``channels``.
    4. Every input channel must be a key of ``channels``, and at least one
       of them must be among the collected node triggers.
    5. Every output channel and stream channel must be a key of ``channels``.
    6. Unless given as ``"*"``, the interrupt-after and then the
       interrupt-before collections may only name keys of ``nodes``.

    Args:
        nodes: Mapping of node name to node.
        channels: Mapping of channel name to channel.
        managed: Managed values, keyed by name.
        input_channels: A single input channel name or a sequence of names.
        output_channels: A single output channel name or a sequence of names.
        stream_channels: A single stream channel name, a sequence of names,
            or ``None`` to skip this check.
        interrupt_after_nodes: ``"*"`` or a sequence of node names.
        interrupt_before_nodes: ``"*"`` or a sequence of node names.

    Raises:
        ValueError: A reserved or unknown name is used, or no input channel
            is subscribed to by any node.
        TypeError: A value in ``nodes`` is not a ``PregelNode``.
    """

    def known() -> str:
        # Sorted channel names, truncated to 100 chars for error messages.
        return repr(sorted(channels))[:100]

    def as_iterable(value: str | Sequence[str]) -> Sequence[str]:
        # A bare string names one channel; it must not be iterated per char.
        return (value,) if isinstance(value, str) else value

    # Reserved names may not be claimed by channels or managed values.
    for chan in channels:
        if chan in RESERVED:
            raise ValueError(f"Channel name '{chan}' is reserved")
    for key in managed:
        if key in RESERVED:
            raise ValueError(f"Managed name '{key}' is reserved")

    # Per node: reserved-name check, type check, then its reads. Triggers are
    # collected along the way for the subscription checks further down.
    triggers = set[str]()
    for node_name, node in nodes.items():
        if node_name in RESERVED:
            raise ValueError(f"Node name '{node_name}' is reserved")
        if not isinstance(node, PregelNode):
            raise TypeError(
                f"Invalid node type {type(node)}, expected PregelNode or NodeBuilder"
            )
        triggers.update(node.triggers)
        reads = node.channels
        if isinstance(reads, str):
            # A single-string read is only checked against ``channels``;
            # managed names are accepted only in the collection form below.
            if reads not in channels:
                raise ValueError(
                    f"Node {node_name} reads channel '{reads}' "
                    f"not in known channels: '{known()}'"
                )
            continue
        for chan in reads:
            if not (chan in channels or chan in managed):
                raise ValueError(
                    f"Node {node_name} reads channel '{chan}' "
                    f"not in known channels: '{known()}'"
                )

    # Every trigger must refer to a real channel.
    for chan in triggers:
        if chan not in channels:
            raise ValueError(
                f"Subscribed channel '{chan}' "
                f"not in known channels: '{known()}'"
            )

    # Input channels must exist and at least one must trigger some node.
    if isinstance(input_channels, str):
        if input_channels not in channels:
            raise ValueError(
                f"Input channel '{input_channels}' "
                f"not in known channels: '{known()}'"
            )
        if input_channels not in triggers:
            raise ValueError(
                f"Input channel {input_channels} is not subscribed to by any node"
            )
    else:
        for chan in input_channels:
            if chan not in channels:
                raise ValueError(f"Input channel '{chan}' not in '{known()}'")
        if not any(chan in triggers for chan in input_channels):
            raise ValueError(
                f"None of the input channels {input_channels} "
                f"are subscribed to by any node"
            )

    # Output and stream channels are validated together as one set.
    emitted = set[str]()
    emitted.update(as_iterable(output_channels))
    if stream_channels is not None:
        emitted.update(as_iterable(stream_channels))
    for chan in emitted:
        if chan not in channels:
            raise ValueError(
                f"Output channel '{chan}' "
                f"not in known channels: '{known()}'"
            )

    # "*" means "every node"; otherwise each listed name must be a node.
    for requested in (interrupt_after_nodes, interrupt_before_nodes):
        if requested == "*":
            continue
        for node_name in requested:
            if node_name not in nodes:
                raise ValueError(f"Node {node_name} not in nodes")
def validate_keys(
    keys: str | Sequence[str] | None,
    channels: Mapping[str, Any],
) -> None:
    """Ensure every requested key is present in ``channels``.

    Args:
        keys: A single key, a sequence of keys, or ``None`` (nothing to
            check). A bare string is treated as one key, not as a sequence
            of characters.
        channels: Mapping whose keys are the valid names.

    Raises:
        ValueError: For the first key, in iteration order, that is missing
            from ``channels``.
    """
    if keys is None:
        return
    requested = (keys,) if isinstance(keys, str) else keys
    for key in requested:
        if key not in channels:
            raise ValueError(f"Key {key} not in channels")
|