- from __future__ import annotations
- import logging
- from collections.abc import AsyncIterator, Iterator, Sequence
- from dataclasses import asdict
- from typing import (
- Any,
- Literal,
- cast,
- )
- from uuid import UUID
- import langsmith as ls
- from langchain_core.runnables import RunnableConfig
- from langchain_core.runnables.graph import (
- Edge as DrawableEdge,
- )
- from langchain_core.runnables.graph import (
- Graph as DrawableGraph,
- )
- from langchain_core.runnables.graph import (
- Node as DrawableNode,
- )
- from langgraph.checkpoint.base import CheckpointMetadata
- from langgraph_sdk.client import (
- LangGraphClient,
- SyncLangGraphClient,
- get_client,
- get_sync_client,
- )
- from langgraph_sdk.schema import (
- Checkpoint,
- QueryParamTypes,
- ThreadState,
- )
- from langgraph_sdk.schema import (
- Command as CommandSDK,
- )
- from langgraph_sdk.schema import (
- StreamMode as StreamModeSDK,
- )
- from typing_extensions import Self
- from langgraph._internal._config import merge_configs
- from langgraph._internal._constants import (
- CONF,
- CONFIG_KEY_CHECKPOINT_ID,
- CONFIG_KEY_CHECKPOINT_MAP,
- CONFIG_KEY_CHECKPOINT_NS,
- CONFIG_KEY_STREAM,
- CONFIG_KEY_TASK_ID,
- INTERRUPT,
- NS_SEP,
- )
- from langgraph.errors import GraphInterrupt, ParentCommand
- from langgraph.pregel.protocol import PregelProtocol, StreamProtocol
- from langgraph.types import (
- All,
- Command,
- Interrupt,
- PregelTask,
- StateSnapshot,
- StreamMode,
- )
- logger = logging.getLogger(__name__)
- __all__ = ("RemoteGraph", "RemoteException")
- _CONF_DROPLIST = frozenset(
- (
- CONFIG_KEY_CHECKPOINT_MAP,
- CONFIG_KEY_CHECKPOINT_ID,
- CONFIG_KEY_CHECKPOINT_NS,
- CONFIG_KEY_TASK_ID,
- ),
- )
- def _sanitize_config_value(v: Any) -> Any:
- """Recursively sanitize a config value to ensure it contains only primitives."""
- if isinstance(v, (str, int, float, bool, UUID)):
- return v
- elif isinstance(v, dict):
- sanitized_dict = {}
- for k, val in v.items():
- if isinstance(k, str):
- sanitized_value = _sanitize_config_value(val)
- if sanitized_value is not None:
- sanitized_dict[k] = sanitized_value
- return sanitized_dict
- elif isinstance(v, (list, tuple)):
- sanitized_list = []
- for item in v:
- sanitized_item = _sanitize_config_value(item)
- if sanitized_item is not None:
- sanitized_list.append(sanitized_item)
- return sanitized_list
- return None
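- # Illustrative behavior of `_sanitize_config_value` (comment-only sketch; the
- # values below are hypothetical, not taken from this module):
- #
- #     _sanitize_config_value({"user_id": UUID(int=0), "callbacks": object()})
- #     # -> {"user_id": UUID("00000000-...-000000000000")}  # non-primitive value dropped
- #     _sanitize_config_value(["a", 1, object()])
- #     # -> ["a", 1]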
- class RemoteException(Exception):
- """Exception raised when an error occurs in the remote graph."""
- pass
- class RemoteGraph(PregelProtocol):
- """The `RemoteGraph` class is a client implementation for calling remote
- APIs that implement the LangGraph Server API specification.
- For example, `RemoteGraph` can be used to call the APIs of graphs deployed
- on LangSmith Deployment.
- `RemoteGraph` behaves the same way as a `Graph` and can be used directly as
- a node in another `Graph`.
- """
- assistant_id: str
- name: str | None
- def __init__(
- self,
- assistant_id: str, # graph_id
- /,
- *,
- url: str | None = None,
- api_key: str | None = None,
- headers: dict[str, str] | None = None,
- client: LangGraphClient | None = None,
- sync_client: SyncLangGraphClient | None = None,
- config: RunnableConfig | None = None,
- name: str | None = None,
- distributed_tracing: bool = False,
- ):
- """Specify `url`, `api_key`, and/or `headers` to create default sync and async clients.
- If `client` or `sync_client` are provided, they will be used instead of the default clients.
- See `LangGraphClient` and `SyncLangGraphClient` for details on the default clients. At least
- one of `url`, `client`, or `sync_client` must be provided.
- Args:
- assistant_id: The assistant ID or graph name of the remote graph to use.
- url: The URL of the remote API.
- api_key: The API key to use for authentication. If not provided, it will be read from the environment (`LANGGRAPH_API_KEY`, `LANGSMITH_API_KEY`, or `LANGCHAIN_API_KEY`).
- headers: Additional headers to include in the requests.
- client: A `LangGraphClient` instance to use instead of creating a default client.
- sync_client: A `SyncLangGraphClient` instance to use instead of creating a default client.
- config: An optional `RunnableConfig` instance with additional configuration.
- name: Human-readable name to attach to the RemoteGraph instance.
- This is useful for adding `RemoteGraph` as a subgraph via `graph.add_node(remote_graph)`.
- If not provided, defaults to the assistant ID.
- distributed_tracing: Whether to enable sending LangSmith distributed tracing headers.
- """
- self.assistant_id = assistant_id
- if name is None:
- self.name = assistant_id
- else:
- self.name = name
- self.config = config
- self.distributed_tracing = distributed_tracing
- if client is None and url is not None:
- client = get_client(url=url, api_key=api_key, headers=headers)
- self.client = client
- if sync_client is None and url is not None:
- sync_client = get_sync_client(url=url, api_key=api_key, headers=headers)
- self.sync_client = sync_client
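- # Illustrative initialization sketch (URL and API key are placeholders):
- #
- #     # Option 1: let RemoteGraph build default clients from a URL.
- #     remote = RemoteGraph("agent", url="http://localhost:2024", api_key="...")
- #
- #     # Option 2: pass pre-built clients from langgraph_sdk and omit `url`.
- #     remote = RemoteGraph(
- #         "agent",
- #         client=get_client(url="http://localhost:2024"),
- #         sync_client=get_sync_client(url="http://localhost:2024"),
- #     )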
- def _validate_client(self) -> LangGraphClient:
- if self.client is None:
- raise ValueError(
- "Async client is not initialized: please provide `url` or `client` when initializing `RemoteGraph`."
- )
- return self.client
- def _validate_sync_client(self) -> SyncLangGraphClient:
- if self.sync_client is None:
- raise ValueError(
- "Sync client is not initialized: please provide `url` or `sync_client` when initializing `RemoteGraph`."
- )
- return self.sync_client
- def copy(self, update: dict[str, Any]) -> Self:
- attrs = {**self.__dict__, **update}
- return self.__class__(attrs.pop("assistant_id"), **attrs)
- def with_config(self, config: RunnableConfig | None = None, **kwargs: Any) -> Self:
- return self.copy(
- {"config": merge_configs(self.config, config, cast(RunnableConfig, kwargs))}
- )
- def _get_drawable_nodes(
- self, graph: dict[str, list[dict[str, Any]]]
- ) -> dict[str, DrawableNode]:
- nodes = {}
- for node in graph["nodes"]:
- node_id = str(node["id"])
- node_data = node.get("data", {})
- # Prefer the node's own `name`; fall back to the name in `node_data`, then to the node id.
- node_name = node.get("name")
- if node_name is None:
- if isinstance(node_data, dict):
- node_name = node_data.get("name", node_id)
- else:
- node_name = node_id
- nodes[node_id] = DrawableNode(
- id=node_id,
- name=node_name,
- data=node_data,
- metadata=node.get("metadata"),
- )
- return nodes
- def get_graph(
- self,
- config: RunnableConfig | None = None,
- *,
- xray: int | bool = False,
- headers: dict[str, str] | None = None,
- params: QueryParamTypes | None = None,
- ) -> DrawableGraph:
- """Get graph by graph name.
- This method calls `GET /assistants/{assistant_id}/graph`.
- Args:
- config: This parameter is not used.
- xray: Include graph representation of subgraphs. If an integer
- value is provided, only subgraphs with a depth less than or
- equal to the value will be included.
- headers: Optional custom headers to include with the request.
- params: Optional query parameters to include with the request.
- Returns:
- The assistant's graph as a `DrawableGraph`.
- """
- sync_client = self._validate_sync_client()
- graph = sync_client.assistants.get_graph(
- assistant_id=self.assistant_id,
- xray=xray,
- headers=headers,
- params=params,
- )
- return DrawableGraph(
- nodes=self._get_drawable_nodes(graph),
- edges=[DrawableEdge(**edge) for edge in graph["edges"]],
- )
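- # Comment-only sketch of fetching and inspecting the drawable graph
- # (assumes `remote` was constructed as in the earlier examples):
- #
- #     drawable = remote.get_graph(xray=True)
- #     print(drawable.nodes)   # dict of DrawableNode keyed by node id
- #     print(drawable.edges)   # list of DrawableEdge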
- async def aget_graph(
- self,
- config: RunnableConfig | None = None,
- *,
- xray: int | bool = False,
- headers: dict[str, str] | None = None,
- params: QueryParamTypes | None = None,
- ) -> DrawableGraph:
- """Get graph by graph name.
- This method calls `GET /assistants/{assistant_id}/graph`.
- Args:
- config: This parameter is not used.
- xray: Include graph representation of subgraphs. If an integer
- value is provided, only subgraphs with a depth less than or
- equal to the value will be included.
- headers: Optional custom headers to include with the request.
- params: Optional query parameters to include with the request.
- Returns:
- The assistant's graph as a `DrawableGraph`.
- """
- client = self._validate_client()
- graph = await client.assistants.get_graph(
- assistant_id=self.assistant_id,
- xray=xray,
- headers=headers,
- params=params,
- )
- return DrawableGraph(
- nodes=self._get_drawable_nodes(graph),
- edges=[DrawableEdge(**edge) for edge in graph["edges"]],
- )
- def _create_state_snapshot(self, state: ThreadState) -> StateSnapshot:
- tasks: list[PregelTask] = []
- for task in state["tasks"]:
- interrupts = tuple(
- Interrupt(**interrupt) for interrupt in task["interrupts"]
- )
- tasks.append(
- PregelTask(
- id=task["id"],
- name=task["name"],
- path=tuple(),
- error=Exception(task["error"]) if task["error"] else None,
- interrupts=interrupts,
- state=(
- self._create_state_snapshot(task["state"])
- if task["state"]
- else (
- cast(RunnableConfig, {"configurable": task["checkpoint"]})
- if task["checkpoint"]
- else None
- )
- ),
- result=task.get("result"),
- )
- )
- return StateSnapshot(
- values=state["values"],
- next=tuple(state["next"]) if state["next"] else tuple(),
- config={
- "configurable": {
- "thread_id": state["checkpoint"]["thread_id"],
- "checkpoint_ns": state["checkpoint"]["checkpoint_ns"],
- "checkpoint_id": state["checkpoint"]["checkpoint_id"],
- "checkpoint_map": state["checkpoint"].get("checkpoint_map", {}),
- }
- },
- metadata=CheckpointMetadata(**state["metadata"]),
- created_at=state["created_at"],
- parent_config=(
- {
- "configurable": {
- "thread_id": state["parent_checkpoint"]["thread_id"],
- "checkpoint_ns": state["parent_checkpoint"]["checkpoint_ns"],
- "checkpoint_id": state["parent_checkpoint"]["checkpoint_id"],
- "checkpoint_map": state["parent_checkpoint"].get(
- "checkpoint_map", {}
- ),
- }
- }
- if state["parent_checkpoint"]
- else None
- ),
- tasks=tuple(tasks),
- interrupts=tuple([i for task in tasks for i in task.interrupts]),
- )
- def _get_checkpoint(self, config: RunnableConfig | None) -> Checkpoint | None:
- if config is None:
- return None
- checkpoint = {}
- if "thread_id" in config["configurable"]:
- checkpoint["thread_id"] = config["configurable"]["thread_id"]
- if "checkpoint_ns" in config["configurable"]:
- checkpoint["checkpoint_ns"] = config["configurable"]["checkpoint_ns"]
- if "checkpoint_id" in config["configurable"]:
- checkpoint["checkpoint_id"] = config["configurable"]["checkpoint_id"]
- if "checkpoint_map" in config["configurable"]:
- checkpoint["checkpoint_map"] = config["configurable"]["checkpoint_map"]
- return checkpoint if checkpoint else None
- def _get_config(self, checkpoint: Checkpoint) -> RunnableConfig:
- return {
- "configurable": {
- "thread_id": checkpoint["thread_id"],
- "checkpoint_ns": checkpoint["checkpoint_ns"],
- "checkpoint_id": checkpoint["checkpoint_id"],
- "checkpoint_map": checkpoint.get("checkpoint_map", {}),
- }
- }
- def _sanitize_config(self, config: RunnableConfig) -> RunnableConfig:
- """Sanitize the config to remove non-serializable fields."""
- sanitized: RunnableConfig = {}
- if "recursion_limit" in config:
- sanitized["recursion_limit"] = config["recursion_limit"]
- if "tags" in config:
- sanitized["tags"] = [tag for tag in config["tags"] if isinstance(tag, str)]
- if "metadata" in config:
- sanitized["metadata"] = {}
- for k, v in config["metadata"].items():
- if (
- isinstance(k, str)
- and (sanitized_value := _sanitize_config_value(v)) is not None
- ):
- sanitized["metadata"][k] = sanitized_value
- if "configurable" in config:
- sanitized["configurable"] = {}
- for k, v in config["configurable"].items():
- if (
- isinstance(k, str)
- and k not in _CONF_DROPLIST
- and (sanitized_value := _sanitize_config_value(v)) is not None
- ):
- sanitized["configurable"][k] = sanitized_value
- return sanitized
- def get_state(
- self,
- config: RunnableConfig,
- *,
- subgraphs: bool = False,
- headers: dict[str, str] | None = None,
- params: QueryParamTypes | None = None,
- ) -> StateSnapshot:
- """Get the state of a thread.
- This method calls `POST /threads/{thread_id}/state/checkpoint` if a
- checkpoint is specified in the config or `GET /threads/{thread_id}/state`
- if no checkpoint is specified.
- Args:
- config: A `RunnableConfig` that includes `thread_id` in the
- `configurable` field.
- subgraphs: Include subgraphs in the state.
- headers: Optional custom headers to include with the request.
- params: Optional query parameters to include with the request.
- Returns:
- The latest state of the thread.
- """
- sync_client = self._validate_sync_client()
- merged_config = merge_configs(self.config, config)
- state = sync_client.threads.get_state(
- thread_id=merged_config["configurable"]["thread_id"],
- checkpoint=self._get_checkpoint(merged_config),
- subgraphs=subgraphs,
- headers=headers,
- params=params,
- )
- return self._create_state_snapshot(state)
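- # Comment-only sketch: reading the latest state of a thread (the thread ID is
- # hypothetical):
- #
- #     snapshot = remote.get_state({"configurable": {"thread_id": "thread-1"}})
- #     print(snapshot.values)   # current channel values
- #     print(snapshot.next)     # nodes scheduled to run next, if any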
- async def aget_state(
- self,
- config: RunnableConfig,
- *,
- subgraphs: bool = False,
- headers: dict[str, str] | None = None,
- params: QueryParamTypes | None = None,
- ) -> StateSnapshot:
- """Get the state of a thread.
- This method calls `POST /threads/{thread_id}/state/checkpoint` if a
- checkpoint is specified in the config or `GET /threads/{thread_id}/state`
- if no checkpoint is specified.
- Args:
- config: A `RunnableConfig` that includes `thread_id` in the
- `configurable` field.
- subgraphs: Include subgraphs in the state.
- headers: Optional custom headers to include with the request.
- params: Optional query parameters to include with the request.
- Returns:
- The latest state of the thread.
- """
- client = self._validate_client()
- merged_config = merge_configs(self.config, config)
- state = await client.threads.get_state(
- thread_id=merged_config["configurable"]["thread_id"],
- checkpoint=self._get_checkpoint(merged_config),
- subgraphs=subgraphs,
- headers=headers,
- params=params,
- )
- return self._create_state_snapshot(state)
- def get_state_history(
- self,
- config: RunnableConfig,
- *,
- filter: dict[str, Any] | None = None,
- before: RunnableConfig | None = None,
- limit: int | None = None,
- headers: dict[str, str] | None = None,
- params: QueryParamTypes | None = None,
- ) -> Iterator[StateSnapshot]:
- """Get the state history of a thread.
- This method calls `POST /threads/{thread_id}/history`.
- Args:
- config: A `RunnableConfig` that includes `thread_id` in the
- `configurable` field.
- filter: Metadata to filter on.
- before: A `RunnableConfig` that includes checkpoint metadata.
- limit: Max number of states to return.
- headers: Optional custom headers to include with the request.
- params: Optional query parameters to include with the request.
- Returns:
- States of the thread.
- """
- sync_client = self._validate_sync_client()
- merged_config = merge_configs(self.config, config)
- states = sync_client.threads.get_history(
- thread_id=merged_config["configurable"]["thread_id"],
- limit=limit if limit else 10,
- before=self._get_checkpoint(before),
- metadata=filter,
- checkpoint=self._get_checkpoint(merged_config),
- headers=headers,
- params=params,
- )
- for state in states:
- yield self._create_state_snapshot(state)
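- # Comment-only sketch: walking a thread's checkpoint history (the server
- # returns at most `limit` states, defaulting to 10 above):
- #
- #     config = {"configurable": {"thread_id": "thread-1"}}
- #     for snapshot in remote.get_state_history(config, limit=5):
- #         print(snapshot.config["configurable"]["checkpoint_id"])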
- async def aget_state_history(
- self,
- config: RunnableConfig,
- *,
- filter: dict[str, Any] | None = None,
- before: RunnableConfig | None = None,
- limit: int | None = None,
- headers: dict[str, str] | None = None,
- params: QueryParamTypes | None = None,
- ) -> AsyncIterator[StateSnapshot]:
- """Get the state history of a thread.
- This method calls `POST /threads/{thread_id}/history`.
- Args:
- config: A `RunnableConfig` that includes `thread_id` in the
- `configurable` field.
- filter: Metadata to filter on.
- before: A `RunnableConfig` that includes checkpoint metadata.
- limit: Max number of states to return.
- headers: Optional custom headers to include with the request.
- params: Optional query parameters to include with the request.
- Returns:
- States of the thread.
- """
- client = self._validate_client()
- merged_config = merge_configs(self.config, config)
- states = await client.threads.get_history(
- thread_id=merged_config["configurable"]["thread_id"],
- limit=limit if limit else 10,
- before=self._get_checkpoint(before),
- metadata=filter,
- checkpoint=self._get_checkpoint(merged_config),
- headers=headers,
- params=params,
- )
- for state in states:
- yield self._create_state_snapshot(state)
- def bulk_update_state(
- self,
- config: RunnableConfig,
- updates: list[tuple[dict[str, Any] | None, str | None]],
- ) -> RunnableConfig:
- raise NotImplementedError
- async def abulk_update_state(
- self,
- config: RunnableConfig,
- updates: list[tuple[dict[str, Any] | None, str | None]],
- ) -> RunnableConfig:
- raise NotImplementedError
- def update_state(
- self,
- config: RunnableConfig,
- values: dict[str, Any] | Any | None,
- as_node: str | None = None,
- *,
- headers: dict[str, str] | None = None,
- params: QueryParamTypes | None = None,
- ) -> RunnableConfig:
- """Update the state of a thread.
- This method calls `POST /threads/{thread_id}/state`.
- Args:
- config: A `RunnableConfig` that includes `thread_id` in the
- `configurable` field.
- values: Values to update the state with.
- as_node: Update the state as if this node had just executed.
- headers: Optional custom headers to include with the request.
- params: Optional query parameters to include with the request.
- Returns:
- `RunnableConfig` for the updated thread.
- """
- sync_client = self._validate_sync_client()
- merged_config = merge_configs(self.config, config)
- response: dict = sync_client.threads.update_state( # type: ignore
- thread_id=merged_config["configurable"]["thread_id"],
- values=values,
- as_node=as_node,
- checkpoint=self._get_checkpoint(merged_config),
- headers=headers,
- params=params,
- )
- return self._get_config(response["checkpoint"])
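- # Comment-only sketch: patching thread state as if a node had produced the
- # update (node name and values are hypothetical):
- #
- #     new_config = remote.update_state(
- #         {"configurable": {"thread_id": "thread-1"}},
- #         values={"messages": [{"role": "user", "content": "edited"}]},
- #         as_node="agent",
- #     )
- #     # `new_config` points at the newly created checkpoint.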
- async def aupdate_state(
- self,
- config: RunnableConfig,
- values: dict[str, Any] | Any | None,
- as_node: str | None = None,
- *,
- headers: dict[str, str] | None = None,
- params: QueryParamTypes | None = None,
- ) -> RunnableConfig:
- """Update the state of a thread.
- This method calls `POST /threads/{thread_id}/state`.
- Args:
- config: A `RunnableConfig` that includes `thread_id` in the
- `configurable` field.
- values: Values to update the state with.
- as_node: Update the state as if this node had just executed.
- headers: Optional custom headers to include with the request.
- params: Optional query parameters to include with the request.
- Returns:
- `RunnableConfig` for the updated thread.
- """
- client = self._validate_client()
- merged_config = merge_configs(self.config, config)
- response: dict = await client.threads.update_state( # type: ignore
- thread_id=merged_config["configurable"]["thread_id"],
- values=values,
- as_node=as_node,
- checkpoint=self._get_checkpoint(merged_config),
- headers=headers,
- params=params,
- )
- return self._get_config(response["checkpoint"])
- def _get_stream_modes(
- self,
- stream_mode: StreamMode | list[StreamMode] | None,
- config: RunnableConfig | None,
- default: StreamMode = "updates",
- ) -> tuple[list[StreamModeSDK], list[StreamModeSDK], bool, StreamProtocol | None]:
- """Return a tuple of the final list of stream modes sent to the
- remote graph and a boolean flag indicating if stream mode 'updates'
- was present in the original list of stream modes.
- 'updates' mode is added to the list of stream modes so that interrupts
- can be detected in the remote graph.
- """
- updated_stream_modes: list[StreamModeSDK] = []
- req_single = True
- # coerce to list, or add default stream mode
- if stream_mode:
- if isinstance(stream_mode, str):
- updated_stream_modes.append(stream_mode)
- else:
- req_single = False
- updated_stream_modes.extend(stream_mode)
- else:
- updated_stream_modes.append(default)
- requested_stream_modes = updated_stream_modes.copy()
- # add any from parent graph
- stream: StreamProtocol | None = (
- (config or {}).get(CONF, {}).get(CONFIG_KEY_STREAM)
- )
- if stream:
- updated_stream_modes.extend(stream.modes)
- # map "messages" to "messages-tuple"
- if "messages" in updated_stream_modes:
- updated_stream_modes.remove("messages")
- updated_stream_modes.append("messages-tuple")
- # if requested "messages-tuple",
- # map to "messages" in requested_stream_modes
- if "messages-tuple" in requested_stream_modes:
- requested_stream_modes.remove("messages-tuple")
- requested_stream_modes.append("messages")
- # add 'updates' mode if not present
- if "updates" not in updated_stream_modes:
- updated_stream_modes.append("updates")
- # remove 'events', as it's not supported in Pregel
- if "events" in updated_stream_modes:
- updated_stream_modes.remove("events")
- return (updated_stream_modes, requested_stream_modes, req_single, stream)
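- # Worked example of the mode mapping above (no parent stream in the config):
- #
- #     self._get_stream_modes(["messages", "values"], None)
- #     # -> (["values", "messages-tuple", "updates"],  # sent to the remote graph
- #     #     ["messages", "values"],                   # modes the caller asked for
- #     #     False,                                    # a list, not a single mode, was requested
- #     #     None)                                     # no parent StreamProtocol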
- def stream(
- self,
- input: dict[str, Any] | Any,
- config: RunnableConfig | None = None,
- *,
- stream_mode: StreamMode | list[StreamMode] | None = None,
- interrupt_before: All | Sequence[str] | None = None,
- interrupt_after: All | Sequence[str] | None = None,
- subgraphs: bool = False,
- headers: dict[str, str] | None = None,
- params: QueryParamTypes | None = None,
- **kwargs: Any,
- ) -> Iterator[dict[str, Any] | Any]:
- """Create a run and stream the results.
- This method calls `POST /threads/{thread_id}/runs/stream` if a `thread_id`
- is specified in the `configurable` field of the config or
- `POST /runs/stream` otherwise.
- Args:
- input: Input to the graph.
- config: A `RunnableConfig` for graph invocation.
- stream_mode: Stream mode(s) to use.
- interrupt_before: Interrupt the graph before these nodes.
- interrupt_after: Interrupt the graph after these nodes.
- subgraphs: Stream from subgraphs.
- headers: Additional headers to pass to the request.
- params: Optional query parameters to include with the request.
- **kwargs: Additional params to pass to `client.runs.stream`.
- Yields:
- The output of the graph.
- """
- sync_client = self._validate_sync_client()
- merged_config = merge_configs(self.config, config)
- sanitized_config = self._sanitize_config(merged_config)
- stream_modes, requested, req_single, stream = self._get_stream_modes(
- stream_mode, config
- )
- if isinstance(input, Command):
- command: CommandSDK | None = cast(CommandSDK, asdict(input))
- input = None
- else:
- command = None
- thread_id = sanitized_config.get("configurable", {}).pop("thread_id", None)
- for chunk in sync_client.runs.stream(
- thread_id=thread_id,
- assistant_id=self.assistant_id,
- input=input,
- command=command,
- config=sanitized_config,
- stream_mode=stream_modes,
- interrupt_before=interrupt_before,
- interrupt_after=interrupt_after,
- stream_subgraphs=subgraphs or stream is not None,
- if_not_exists="create",
- headers=(
- _merge_tracing_headers(headers) if self.distributed_tracing else headers
- ),
- params=params,
- **kwargs,
- ):
- # split mode and ns
- if NS_SEP in chunk.event:
- mode, ns_ = chunk.event.split(NS_SEP, 1)
- ns = tuple(ns_.split(NS_SEP))
- else:
- mode, ns = chunk.event, ()
- # raise ParentCommand exception for command events
- if mode == "command" and chunk.data.get("graph") == Command.PARENT:
- raise ParentCommand(Command(**chunk.data))
- # prepend caller ns (as it is not passed to remote graph)
- if caller_ns := (config or {}).get(CONF, {}).get(CONFIG_KEY_CHECKPOINT_NS):
- caller_ns = tuple(caller_ns.split(NS_SEP))
- ns = caller_ns + ns
- # stream to parent stream
- if stream is not None and mode in stream.modes:
- stream((ns, mode, chunk.data))
- # raise interrupt or errors
- if chunk.event.startswith("updates"):
- if isinstance(chunk.data, dict) and INTERRUPT in chunk.data:
- if caller_ns:
- raise GraphInterrupt(
- [Interrupt(**i) for i in chunk.data[INTERRUPT]]
- )
- elif chunk.event.startswith("error"):
- raise RemoteException(chunk.data)
- # filter for what was actually requested
- if mode not in requested:
- continue
- if chunk.event.startswith("messages"):
- chunk = chunk._replace(data=tuple(chunk.data)) # type: ignore
- # emit chunk
- if subgraphs:
- if NS_SEP in chunk.event:
- mode, ns_ = chunk.event.split(NS_SEP, 1)
- ns = tuple(ns_.split(NS_SEP))
- else:
- mode, ns = chunk.event, ()
- if req_single:
- yield ns, chunk.data
- else:
- yield ns, mode, chunk.data
- elif req_single:
- yield chunk.data
- else:
- yield chunk
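- # Comment-only streaming sketch (graph name, thread ID, and input are
- # hypothetical):
- #
- #     for chunk in remote.stream(
- #         {"messages": [{"role": "user", "content": "hi"}]},
- #         config={"configurable": {"thread_id": "thread-1"}},
- #         stream_mode="updates",
- #     ):
- #         print(chunk)   # with a single stream mode, only the data is yielded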
- async def astream(
- self,
- input: dict[str, Any] | Any,
- config: RunnableConfig | None = None,
- *,
- stream_mode: StreamMode | list[StreamMode] | None = None,
- interrupt_before: All | Sequence[str] | None = None,
- interrupt_after: All | Sequence[str] | None = None,
- subgraphs: bool = False,
- headers: dict[str, str] | None = None,
- params: QueryParamTypes | None = None,
- **kwargs: Any,
- ) -> AsyncIterator[dict[str, Any] | Any]:
- """Create a run and stream the results.
- This method calls `POST /threads/{thread_id}/runs/stream` if a `thread_id`
- is specified in the `configurable` field of the config or
- `POST /runs/stream` otherwise.
- Args:
- input: Input to the graph.
- config: A `RunnableConfig` for graph invocation.
- stream_mode: Stream mode(s) to use.
- interrupt_before: Interrupt the graph before these nodes.
- interrupt_after: Interrupt the graph after these nodes.
- subgraphs: Stream from subgraphs.
- headers: Additional headers to pass to the request.
- params: Optional query parameters to include with the request.
- **kwargs: Additional params to pass to `client.runs.stream`.
- Yields:
- The output of the graph.
- """
- client = self._validate_client()
- merged_config = merge_configs(self.config, config)
- sanitized_config = self._sanitize_config(merged_config)
- stream_modes, requested, req_single, stream = self._get_stream_modes(
- stream_mode, config
- )
- if isinstance(input, Command):
- command: CommandSDK | None = cast(CommandSDK, asdict(input))
- input = None
- else:
- command = None
- thread_id = sanitized_config.get("configurable", {}).pop("thread_id", None)
- async for chunk in client.runs.stream(
- thread_id=thread_id,
- assistant_id=self.assistant_id,
- input=input,
- command=command,
- config=sanitized_config,
- stream_mode=stream_modes,
- interrupt_before=interrupt_before,
- interrupt_after=interrupt_after,
- stream_subgraphs=subgraphs or stream is not None,
- if_not_exists="create",
- headers=(
- _merge_tracing_headers(headers) if self.distributed_tracing else headers
- ),
- params=params,
- **kwargs,
- ):
- # split mode and ns
- if NS_SEP in chunk.event:
- mode, ns_ = chunk.event.split(NS_SEP, 1)
- ns = tuple(ns_.split(NS_SEP))
- else:
- mode, ns = chunk.event, ()
- # raise ParentCommand exception for command events
- if mode == "command" and chunk.data.get("graph") == Command.PARENT:
- raise ParentCommand(Command(**chunk.data))
- # prepend caller ns (as it is not passed to remote graph)
- if caller_ns := (config or {}).get(CONF, {}).get(CONFIG_KEY_CHECKPOINT_NS):
- caller_ns = tuple(caller_ns.split(NS_SEP))
- ns = caller_ns + ns
- # stream to parent stream
- if stream is not None and mode in stream.modes:
- stream((ns, mode, chunk.data))
- # raise interrupt or errors
- if chunk.event.startswith("updates"):
- if isinstance(chunk.data, dict) and INTERRUPT in chunk.data:
- if caller_ns:
- raise GraphInterrupt(
- [Interrupt(**i) for i in chunk.data[INTERRUPT]]
- )
- elif chunk.event.startswith("error"):
- raise RemoteException(chunk.data)
- # filter for what was actually requested
- if mode not in requested:
- continue
- if chunk.event.startswith("messages"):
- chunk = chunk._replace(data=tuple(chunk.data)) # type: ignore
- # emit chunk
- if subgraphs:
- if NS_SEP in chunk.event:
- mode, ns_ = chunk.event.split(NS_SEP, 1)
- ns = tuple(ns_.split(NS_SEP))
- else:
- mode, ns = chunk.event, ()
- if req_single:
- yield ns, chunk.data
- else:
- yield ns, mode, chunk.data
- elif req_single:
- yield chunk.data
- else:
- yield chunk
- async def astream_events(
- self,
- input: Any,
- config: RunnableConfig | None = None,
- *,
- version: Literal["v1", "v2"],
- include_names: Sequence[All] | None = None,
- include_types: Sequence[All] | None = None,
- include_tags: Sequence[All] | None = None,
- exclude_names: Sequence[All] | None = None,
- exclude_types: Sequence[All] | None = None,
- exclude_tags: Sequence[All] | None = None,
- **kwargs: Any,
- ) -> AsyncIterator[dict[str, Any]]:
- raise NotImplementedError
- def invoke(
- self,
- input: dict[str, Any] | Any,
- config: RunnableConfig | None = None,
- *,
- interrupt_before: All | Sequence[str] | None = None,
- interrupt_after: All | Sequence[str] | None = None,
- headers: dict[str, str] | None = None,
- params: QueryParamTypes | None = None,
- **kwargs: Any,
- ) -> dict[str, Any] | Any:
- """Create a run, wait until it finishes and return the final state.
- Args:
- input: Input to the graph.
- config: A `RunnableConfig` for graph invocation.
- interrupt_before: Interrupt the graph before these nodes.
- interrupt_after: Interrupt the graph after these nodes.
- headers: Additional headers to pass to the request.
- params: Optional query parameters to include with the request.
- **kwargs: Additional params to pass to `RemoteGraph.stream`.
- Returns:
- The output of the graph.
- """
- for chunk in self.stream(
- input,
- config=config,
- interrupt_before=interrupt_before,
- interrupt_after=interrupt_after,
- headers=headers,
- stream_mode="values",
- params=params,
- **kwargs,
- ):
- pass
- try:
- return chunk
- except UnboundLocalError:
- logger.warning("No events received from remote graph")
- return None
- async def ainvoke(
- self,
- input: dict[str, Any] | Any,
- config: RunnableConfig | None = None,
- *,
- interrupt_before: All | Sequence[str] | None = None,
- interrupt_after: All | Sequence[str] | None = None,
- headers: dict[str, str] | None = None,
- params: QueryParamTypes | None = None,
- **kwargs: Any,
- ) -> dict[str, Any] | Any:
- """Create a run, wait until it finishes and return the final state.
- Args:
- input: Input to the graph.
- config: A `RunnableConfig` for graph invocation.
- interrupt_before: Interrupt the graph before these nodes.
- interrupt_after: Interrupt the graph after these nodes.
- headers: Additional headers to pass to the request.
- params: Optional query parameters to include with the request.
- **kwargs: Additional params to pass to `RemoteGraph.astream`.
- Returns:
- The output of the graph.
- """
- async for chunk in self.astream(
- input,
- config=config,
- interrupt_before=interrupt_before,
- interrupt_after=interrupt_after,
- headers=headers,
- stream_mode="values",
- params=params,
- **kwargs,
- ):
- pass
- try:
- return chunk
- except UnboundLocalError:
- logger.warning("No events received from remote graph")
- return None
- def _merge_tracing_headers(headers: dict[str, str] | None) -> dict[str, str] | None:
- if rt := ls.get_current_run_tree():
- tracing_headers = rt.to_headers()
- if headers:
- if "baggage" in headers:
- tracing_headers["baggage"] = (
- f"{headers['baggage']},{tracing_headers['baggage']}"
- )
- headers.update(tracing_headers)
- else:
- headers = tracing_headers
- return headers
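- # Comment-only sketch of distributed tracing: when constructed with
- # `distributed_tracing=True`, outgoing run requests merge the current
- # LangSmith run tree headers (via `_merge_tracing_headers`) into any
- # user-supplied headers, so remote runs are linked to the caller's trace.
- #
- #     remote = RemoteGraph("agent", url="http://localhost:2024",
- #                          distributed_tracing=True)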