- from __future__ import annotations
- from collections.abc import AsyncIterator, Callable, Iterator, Sequence
- from typing import (
- Any,
- TypeVar,
- cast,
- )
- from uuid import UUID, uuid4
- from langchain_core.callbacks import BaseCallbackHandler
- from langchain_core.messages import BaseMessage
- from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, LLMResult
- from langgraph._internal._constants import NS_SEP
- from langgraph.constants import TAG_HIDDEN, TAG_NOSTREAM
- from langgraph.pregel.protocol import StreamChunk
- from langgraph.types import Command
- try:
- from langchain_core.tracers._streaming import _StreamingCallbackHandler
- except ImportError:
- _StreamingCallbackHandler = object # type: ignore
- T = TypeVar("T")
- Meta = tuple[tuple[str, ...], dict[str, Any]]
- class StreamMessagesHandler(BaseCallbackHandler, _StreamingCallbackHandler):
- """A callback handler that implements stream_mode=messages.
- Collects messages from:
- (1) chat model stream events; and
- (2) node outputs.
- """
- run_inline = True
- """We want this callback to run in the main thread to avoid order/locking issues."""
- def __init__(
- self,
- stream: Callable[[StreamChunk], None],
- subgraphs: bool,
- *,
- parent_ns: tuple[str, ...] | None = None,
- ) -> None:
- """Configure the handler to stream messages from LLMs and nodes.
- Args:
- stream: A callable that takes a StreamChunk and emits it.
- subgraphs: Whether to emit messages from subgraphs.
- parent_ns: The namespace where the handler was created.
- We keep track of this namespace so that subgraph calls which
- explicitly request `stream_mode="messages"` are still streamed,
- even when subgraph streaming is otherwise disabled.
- Example:
- parent_ns is used to handle scenarios where the subgraph is explicitly
- streamed with `stream_mode="messages"`.
- ```python
- async def parent_graph_node():
-     # This node is in the parent graph.
-     async for event in some_subgraph.astream(..., stream_mode="messages"):
-         ...  # <-- these events will be emitted
-     return ...
- parent_graph.invoke(..., subgraphs=False)
- ```
- """
- self.stream = stream
- self.subgraphs = subgraphs
- self.metadata: dict[UUID, Meta] = {}
- self.seen: set[int | str] = set()
- self.parent_ns = parent_ns
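- # Emit a single message on the "messages" stream channel: optionally skip
- # ids that were already seen, assign an id if missing, record the id as
- # seen, and push the message together with its namespace and metadata.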
- def _emit(self, meta: Meta, message: BaseMessage, *, dedupe: bool = False) -> None:
- if dedupe and message.id in self.seen:
- return
- else:
- if message.id is None:
- message.id = str(uuid4())
- self.seen.add(message.id)
- self.stream((meta[0], "messages", (message, meta[1])))
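- # Walk a node's output looking for messages to emit: the output may be a
- # message itself, a sequence of messages, a dict whose values are messages
- # or sequences of messages, or an arbitrary state object whose attributes
- # are inspected via dir().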
- def _find_and_emit_messages(self, meta: Meta, response: Any) -> None:
- if isinstance(response, BaseMessage):
- self._emit(meta, response, dedupe=True)
- elif isinstance(response, Sequence):
- for value in response:
- if isinstance(value, BaseMessage):
- self._emit(meta, value, dedupe=True)
- elif isinstance(response, dict):
- for value in response.values():
- if isinstance(value, BaseMessage):
- self._emit(meta, value, dedupe=True)
- elif isinstance(value, Sequence):
- for item in value:
- if isinstance(item, BaseMessage):
- self._emit(meta, item, dedupe=True)
- elif hasattr(response, "__dir__") and callable(response.__dir__):
- for key in dir(response):
- try:
- value = getattr(response, key)
- if isinstance(value, BaseMessage):
- self._emit(meta, value, dedupe=True)
- elif isinstance(value, Sequence):
- for item in value:
- if isinstance(item, BaseMessage):
- self._emit(meta, item, dedupe=True)
- except AttributeError:
- pass
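- # The tap hooks below satisfy the _StreamingCallbackHandler interface;
- # messages are collected via the callback methods instead, so node output
- # iterators are returned unchanged.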
- def tap_output_aiter(
- self, run_id: UUID, output: AsyncIterator[T]
- ) -> AsyncIterator[T]:
- return output
- def tap_output_iter(self, run_id: UUID, output: Iterator[T]) -> Iterator[T]:
- return output
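- # Register a chat model run so its tokens can be streamed: skip runs tagged
- # with TAG_NOSTREAM, and skip runs inside subgraphs unless subgraph streaming
- # was requested or the run belongs to the handler's own parent namespace.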
- def on_chat_model_start(
- self,
- serialized: dict[str, Any],
- messages: list[list[BaseMessage]],
- *,
- run_id: UUID,
- parent_run_id: UUID | None = None,
- tags: list[str] | None = None,
- metadata: dict[str, Any] | None = None,
- **kwargs: Any,
- ) -> Any:
- if metadata and (not tags or (TAG_NOSTREAM not in tags)):
- ns = tuple(cast(str, metadata["langgraph_checkpoint_ns"]).split(NS_SEP))[
- :-1
- ]
- if not self.subgraphs and len(ns) > 0 and ns != self.parent_ns:
- return
- if tags:
- if filtered_tags := [t for t in tags if not t.startswith("seq:step")]:
- metadata["tags"] = filtered_tags
- self.metadata[run_id] = (ns, metadata)
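- # Emit each streamed chat model chunk as it arrives, provided the run was
- # registered in on_chat_model_start.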
- def on_llm_new_token(
- self,
- token: str,
- *,
- chunk: ChatGenerationChunk | None = None,
- run_id: UUID,
- parent_run_id: UUID | None = None,
- tags: list[str] | None = None,
- **kwargs: Any,
- ) -> Any:
- if not isinstance(chunk, ChatGenerationChunk):
- return
- if meta := self.metadata.get(run_id):
- self._emit(meta, chunk.message)
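- # When a chat model run finishes, emit its final message (deduplicated, so
- # runs that already streamed token chunks are not emitted twice) and forget
- # the run's metadata.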
- def on_llm_end(
- self,
- response: LLMResult,
- *,
- run_id: UUID,
- parent_run_id: UUID | None = None,
- **kwargs: Any,
- ) -> Any:
- if meta := self.metadata.get(run_id):
- if response.generations and response.generations[0]:
- gen = response.generations[0][0]
- if isinstance(gen, ChatGeneration):
- self._emit(meta, gen.message, dedupe=True)
- self.metadata.pop(run_id, None)
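- # If a chat model run fails, drop its metadata so nothing more is emitted for it.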
- def on_llm_error(
- self,
- error: BaseException,
- *,
- run_id: UUID,
- parent_run_id: UUID | None = None,
- **kwargs: Any,
- ) -> Any:
- self.metadata.pop(run_id, None)
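- # Register a graph node run (a chain run whose name matches the
- # langgraph_node metadata key and that is not tagged as hidden) so that
- # messages in its output can be emitted when it finishes.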
- def on_chain_start(
- self,
- serialized: dict[str, Any],
- inputs: dict[str, Any],
- *,
- run_id: UUID,
- parent_run_id: UUID | None = None,
- tags: list[str] | None = None,
- metadata: dict[str, Any] | None = None,
- **kwargs: Any,
- ) -> Any:
- if (
- metadata
- and kwargs.get("name") == metadata.get("langgraph_node")
- and (not tags or TAG_HIDDEN not in tags)
- ):
- ns = tuple(cast(str, metadata["langgraph_checkpoint_ns"]).split(NS_SEP))[
- :-1
- ]
- if not self.subgraphs and len(ns) > 0:
- return
- self.metadata[run_id] = (ns, metadata)
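- # Remember ids of messages already present in the node's input so they are
- # not re-emitted when they reappear in the node's output.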
- if isinstance(inputs, dict):
- for key, value in inputs.items():
- if isinstance(value, BaseMessage):
- if value.id is not None:
- self.seen.add(value.id)
- elif isinstance(value, Sequence) and not isinstance(value, str):
- for item in value:
- if isinstance(item, BaseMessage):
- if item.id is not None:
- self.seen.add(item.id)
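- # When a registered node run finishes, search its output for messages and
- # emit any that have not been emitted already, unpacking Command updates
- # (single or in a list) along the way.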
- def on_chain_end(
- self,
- response: Any,
- *,
- run_id: UUID,
- parent_run_id: UUID | None = None,
- **kwargs: Any,
- ) -> Any:
- if meta := self.metadata.pop(run_id, None):
- # Handle Command node updates
- if isinstance(response, Command):
- self._find_and_emit_messages(meta, response.update)
- # Handle list of Command updates
- elif isinstance(response, Sequence) and any(
- isinstance(value, Command) for value in response
- ):
- for value in response:
- if isinstance(value, Command):
- self._find_and_emit_messages(meta, value.update)
- else:
- self._find_and_emit_messages(meta, value)
- # Handle basic updates / streaming
- else:
- self._find_and_emit_messages(meta, response)
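- # If a node run fails, drop its metadata so its output is never inspected.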
- def on_chain_error(
- self,
- error: BaseException,
- *,
- run_id: UUID,
- parent_run_id: UUID | None = None,
- **kwargs: Any,
- ) -> Any:
- self.metadata.pop(run_id, None)
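- # Illustrative sketch only: callers do not construct this handler directly.
- # It is attached internally when a graph is streamed with
- # stream_mode="messages", which yields (message_chunk, metadata) tuples;
- # `graph` and `inputs` below are assumed placeholders.
- #
- #     for message_chunk, metadata in graph.stream(inputs, stream_mode="messages"):
- #         print(metadata.get("langgraph_node"), message_chunk.content)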