| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372 |
- from __future__ import annotations
- import uuid
- import warnings
- from collections.abc import Callable, Sequence
- from functools import partial
- from typing import (
- Annotated,
- Any,
- Literal,
- cast,
- )
- from langchain_core.messages import (
- AnyMessage,
- BaseMessage,
- BaseMessageChunk,
- MessageLikeRepresentation,
- RemoveMessage,
- convert_to_messages,
- message_chunk_to_message,
- )
- from typing_extensions import TypedDict, deprecated
- from langgraph._internal._constants import CONF, CONFIG_KEY_SEND, NS_SEP
- from langgraph.graph.state import StateGraph
- from langgraph.warnings import LangGraphDeprecatedSinceV10
__all__ = (
    "add_messages",
    "MessagesState",
    "MessageGraph",
    "REMOVE_ALL_MESSAGES",
)

# Sentinel ID. When `add_messages` finds a `RemoveMessage` carrying this ID in
# the incoming update, it discards everything up to and including that entry.
REMOVE_ALL_MESSAGES = "__remove_all__"

# Accepted reducer input: one message-like value, or a list of them.
Messages = list[MessageLikeRepresentation] | MessageLikeRepresentation
def _add_messages_wrapper(func: Callable) -> Callable[[Messages, Messages], Messages]:
    """Let a two-argument message reducer double as a reducer factory.

    The returned callable behaves in one of three ways:

    * both `left` and `right` are given: it calls
      `func(left, right, **kwargs)` and returns the result;
    * neither is given: it returns `functools.partial(func, **kwargs)`, i.e.
      the reducer with its keyword options pre-bound;
    * exactly one is given: it raises `ValueError`.

    Args:
        func: The reducer to wrap, invoked as `func(left, right, **kwargs)`.

    Returns:
        The wrapping callable. Its `__doc__` is copied from `func`.
    """

    def _add_messages(
        left: Messages | None = None, right: Messages | None = None, **kwargs: Any
    ) -> Messages | Callable[[Messages, Messages], Messages]:
        if left is not None and right is not None:
            return func(left, right, **kwargs)
        if left is None and right is None:
            return partial(func, **kwargs)
        # Exactly one side was supplied. Use an identity check to name it: a
        # truthiness test would report an empty list/string as "not received".
        received = "left" if left is not None else "right"
        msg = (
            f"Must specify non-null arguments for both 'left' and 'right'. Only "
            f"received: '{received}'."
        )
        raise ValueError(msg)

    _add_messages.__doc__ = func.__doc__
    # The target type is quoted so it is only read by type checkers and never
    # evaluated at runtime.
    return cast("Callable[[Messages, Messages], Messages]", _add_messages)
@_add_messages_wrapper
def add_messages(
    left: Messages,
    right: Messages,
    *,
    format: Literal["langchain-openai"] | None = None,
) -> Messages:
    """Merges two lists of messages, updating existing messages by ID.

    By default, this ensures the state is "append-only", unless the
    new message has the same ID as an existing message.

    Both inputs are coerced to lists of message objects, and any message
    without an ID is assigned a fresh UUID (set on the message object itself).

    A `RemoveMessage` in `right` deletes the message with the same ID. If
    `right` contains a `RemoveMessage` whose ID is `REMOVE_ALL_MESSAGES`,
    all of `left` and every entry of `right` up to and including the last such
    marker are discarded; only the entries of `right` after it are returned.

    Args:
        left: The base list of `Messages`.
        right: The list of `Messages` (or single `Message`) to merge
            into the base list.
        format: The format to return messages in. If `None`, the merged
            messages are returned without further reformatting. If
            `langchain-openai` then `Messages` will be returned as
            `BaseMessage` objects with their contents formatted to match OpenAI
            message format, meaning contents can be string, `'text'` blocks, or
            `'image_url'` blocks and tool responses are returned as their own
            `ToolMessage` objects.

            !!! important "Requirement"

                Must have `langchain-core>=0.3.11` installed to use this feature.

    Returns:
        A new list of messages with the messages from `right` merged into `left`.
        If a message in `right` has the same ID as a message in `left`, the
        message from `right` will replace the message from `left`.

    Raises:
        ValueError: If a `RemoveMessage` in `right` refers to an ID that is not
            present in the merged list, or if `format` is not a recognized
            value.

    Example: Basic usage
        ```python
        from langchain_core.messages import AIMessage, HumanMessage

        msgs1 = [HumanMessage(content="Hello", id="1")]
        msgs2 = [AIMessage(content="Hi there!", id="2")]
        add_messages(msgs1, msgs2)
        # [HumanMessage(content='Hello', id='1'), AIMessage(content='Hi there!', id='2')]
        ```

    Example: Overwrite existing message
        ```python
        msgs1 = [HumanMessage(content="Hello", id="1")]
        msgs2 = [HumanMessage(content="Hello again", id="1")]
        add_messages(msgs1, msgs2)
        # [HumanMessage(content='Hello again', id='1')]
        ```

    Example: Use in a StateGraph
        ```python
        from typing import Annotated
        from typing_extensions import TypedDict
        from langgraph.graph import StateGraph


        class State(TypedDict):
            messages: Annotated[list, add_messages]


        builder = StateGraph(State)
        builder.add_node("chatbot", lambda state: {"messages": [("assistant", "Hello")]})
        builder.set_entry_point("chatbot")
        builder.set_finish_point("chatbot")
        graph = builder.compile()
        graph.invoke({})
        # {'messages': [AIMessage(content='Hello', id=...)]}
        ```

    Example: Use OpenAI message format
        ```python
        from typing import Annotated
        from typing_extensions import TypedDict
        from langgraph.graph import StateGraph, add_messages


        class State(TypedDict):
            messages: Annotated[list, add_messages(format="langchain-openai")]


        def chatbot_node(state: State) -> dict:
            return {
                "messages": [
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": "Here's an image:",
                                "cache_control": {"type": "ephemeral"},
                            },
                            {
                                "type": "image",
                                "source": {
                                    "type": "base64",
                                    "media_type": "image/jpeg",
                                    "data": "1234",
                                },
                            },
                        ],
                    },
                ]
            }


        builder = StateGraph(State)
        builder.add_node("chatbot", chatbot_node)
        builder.set_entry_point("chatbot")
        builder.set_finish_point("chatbot")
        graph = builder.compile()
        graph.invoke({"messages": []})
        # {
        #     'messages': [
        #         HumanMessage(
        #             content=[
        #                 {"type": "text", "text": "Here's an image:"},
        #                 {
        #                     "type": "image_url",
        #                     "image_url": {"url": "data:image/jpeg;base64,1234"},
        #                 },
        #             ],
        #         ),
        #     ]
        # }
        ```
    """
    # coerce to list (a tuple such as ("user", "hi") is a single message)
    if not isinstance(left, list):
        left = [left]  # type: ignore[assignment]
    if not isinstance(right, list):
        right = [right]  # type: ignore[assignment]

    # coerce to message
    left = [
        message_chunk_to_message(cast(BaseMessageChunk, m))
        for m in convert_to_messages(left)
    ]
    right = [
        message_chunk_to_message(cast(BaseMessageChunk, m))
        for m in convert_to_messages(right)
    ]

    # assign missing ids, remembering the position of the LAST remove-all marker
    for m in left:
        if m.id is None:
            m.id = str(uuid.uuid4())
    remove_all_idx = None
    for idx, m in enumerate(right):
        if m.id is None:
            m.id = str(uuid.uuid4())
        if isinstance(m, RemoveMessage) and m.id == REMOVE_ALL_MESSAGES:
            remove_all_idx = idx

    if remove_all_idx is not None:
        # Everything up to and including the marker is dropped. The result
        # still flows through the `format` handling below so this path honors
        # (and validates) `format` exactly like the regular merge path.
        merged = right[remove_all_idx + 1 :]
    else:
        # merge
        merged = left.copy()
        merged_by_id = {m.id: i for i, m in enumerate(merged)}
        ids_to_remove = set()
        for m in right:
            if (existing_idx := merged_by_id.get(m.id)) is not None:
                if isinstance(m, RemoveMessage):
                    ids_to_remove.add(m.id)
                else:
                    # A later non-remove message with the same ID cancels an
                    # earlier removal and replaces the message in place.
                    ids_to_remove.discard(m.id)
                    merged[existing_idx] = m
            else:
                if isinstance(m, RemoveMessage):
                    raise ValueError(
                        f"Attempting to delete a message with an ID that doesn't exist ('{m.id}')"
                    )
                merged_by_id[m.id] = len(merged)
                merged.append(m)
        merged = [m for m in merged if m.id not in ids_to_remove]

    if format == "langchain-openai":
        merged = _format_messages(merged)
    elif format:
        msg = f"Unrecognized {format=}. Expected one of 'langchain-openai', None."
        raise ValueError(msg)
    return merged
@deprecated(
    "MessageGraph is deprecated in langgraph 1.0.0, to be removed in 2.0.0. Please use StateGraph with a `messages` key instead.",
    category=None,
)
class MessageGraph(StateGraph):
    """A `StateGraph` whose whole state is one list of messages.

    Deprecated: use `StateGraph` with a `messages` key instead.

    The state schema is `Annotated[list[AnyMessage], add_messages]`, so every
    node receives the current list of messages and returns zero or more
    messages, which `add_messages` merges into that list. The list is
    append-only except where `add_messages` replaces or removes a message
    by ID.

    Examples:
        ```pycon
        >>> from langgraph.graph.message import MessageGraph
        ...
        >>> builder = MessageGraph()
        >>> builder.add_node("chatbot", lambda state: [("assistant", "Hello!")])
        >>> builder.set_entry_point("chatbot")
        >>> builder.set_finish_point("chatbot")
        >>> builder.compile().invoke([("user", "Hi there.")])
        [HumanMessage(content="Hi there.", id='...'), AIMessage(content="Hello!", id='...')]
        ```

        ```pycon
        >>> from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
        >>> from langgraph.graph.message import MessageGraph
        ...
        >>> builder = MessageGraph()
        >>> builder.add_node(
        ...     "chatbot",
        ...     lambda state: [
        ...         AIMessage(
        ...             content="Hello!",
        ...             tool_calls=[{"name": "search", "id": "123", "args": {"query": "X"}}],
        ...         )
        ...     ],
        ... )
        >>> builder.add_node(
        ...     "search", lambda state: [ToolMessage(content="Searching...", tool_call_id="123")]
        ... )
        >>> builder.set_entry_point("chatbot")
        >>> builder.add_edge("chatbot", "search")
        >>> builder.set_finish_point("search")
        >>> builder.compile().invoke([HumanMessage(content="Hi there. Can you search for X?")])
        [HumanMessage(content="Hi there. Can you search for X?", id='...'),
         AIMessage(content="Hello!", id='...'),
         ToolMessage(content="Searching...", id='...', tool_call_id="123")]
        ```
    """

    def __init__(self) -> None:
        # Emitted on every instantiation; stacklevel=2 attributes the warning
        # to the code that constructs the graph.
        deprecation_notice = "MessageGraph is deprecated in LangGraph v1.0.0, to be removed in v2.0.0. Please use StateGraph with a `messages` key instead."
        warnings.warn(
            deprecation_notice,
            category=LangGraphDeprecatedSinceV10,
            stacklevel=2,
        )
        # The entire graph state is the message list, merged by `add_messages`.
        state_schema = Annotated[list[AnyMessage], add_messages]
        super().__init__(state_schema)  # type: ignore[arg-type]
class MessagesState(TypedDict):
    """Typed state schema with a single `messages` key.

    The key is annotated with `add_messages`, the same `Annotated[...]` reducer
    pattern used with `StateGraph` in the `add_messages` examples.
    """

    # List of messages, annotated with the `add_messages` merge function.
    messages: Annotated[list[AnyMessage], add_messages]
def _format_messages(messages: Sequence[BaseMessage]) -> list[BaseMessage]:
    """Round-trip `messages` through the OpenAI message format.

    The messages are passed to `convert_to_openai_messages` and the result is
    turned back into message objects with `convert_to_messages`.

    If `convert_to_openai_messages` cannot be imported from
    `langchain_core.messages`, a warning is emitted and the input is returned
    un-formatted, as a new list.

    Args:
        messages: The messages to format.

    Returns:
        The re-formatted messages, or a plain list copy of `messages` when the
        converter is unavailable.
    """
    try:
        from langchain_core.messages import convert_to_openai_messages
    except ImportError:
        # Best-effort fallback: warn and hand the messages back untouched.
        warnings.warn(
            "Must have langchain-core>=0.3.11 installed to use automatic message "
            "formatting (format='langchain-openai'). Please update your langchain-core "
            "version or remove the 'format' flag. Returning un-formatted "
            "messages."
        )
        return list(messages)

    openai_payload = convert_to_openai_messages(messages)
    return convert_to_messages(openai_payload)
def push_message(
    message: MessageLikeRepresentation | BaseMessageChunk,
    *,
    state_key: str | None = "messages",
) -> AnyMessage:
    """Write a message manually to the `messages` / `messages-tuple` stream mode.

    Will automatically write to the channel specified in the `state_key` unless
    `state_key` is `None`.

    Args:
        message: The message to push. It is coerced with `convert_to_messages`
            and must carry an ID after coercion.
        state_key: Name of the state channel to also write the message to.
            A falsy value (`None` or an empty string) skips the state write.

    Returns:
        The coerced message.

    Raises:
        ValueError: If the coerced message has no ID.
    """
    from langchain_core.callbacks.base import BaseCallbackManager

    from langgraph.config import get_config
    from langgraph.pregel._messages import StreamMessagesHandler

    config = get_config()
    message = convert_to_messages([message])[0]
    if message.id is None:
        raise ValueError("Message ID is required")

    # `callbacks` may be a callback manager, a plain list of handlers, `None`,
    # or absent. Always end up with an iterable so the lookup below cannot hit
    # an unbound name or a KeyError when no callbacks are configured.
    callbacks = config.get("callbacks")
    if isinstance(callbacks, BaseCallbackManager):
        handlers = callbacks.handlers
    elif isinstance(callbacks, list):
        handlers = callbacks
    else:
        handlers = []

    if stream_handler := next(
        (x for x in handlers if isinstance(x, StreamMessagesHandler)), None
    ):
        metadata = config["metadata"]
        # Emitted as (namespace tuple, metadata); the namespace comes from
        # splitting the checkpoint namespace string on NS_SEP.
        message_meta = (
            tuple(cast(str, metadata["langgraph_checkpoint_ns"]).split(NS_SEP)),
            metadata,
        )
        stream_handler._emit(message_meta, message, dedupe=False)

    if state_key:
        config[CONF][CONFIG_KEY_SEND]([(state_key, message)])
    return message
|