message.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372
  1. from __future__ import annotations
  2. import uuid
  3. import warnings
  4. from collections.abc import Callable, Sequence
  5. from functools import partial
  6. from typing import (
  7. Annotated,
  8. Any,
  9. Literal,
  10. cast,
  11. )
  12. from langchain_core.messages import (
  13. AnyMessage,
  14. BaseMessage,
  15. BaseMessageChunk,
  16. MessageLikeRepresentation,
  17. RemoveMessage,
  18. convert_to_messages,
  19. message_chunk_to_message,
  20. )
  21. from typing_extensions import TypedDict, deprecated
  22. from langgraph._internal._constants import CONF, CONFIG_KEY_SEND, NS_SEP
  23. from langgraph.graph.state import StateGraph
  24. from langgraph.warnings import LangGraphDeprecatedSinceV10
# Names exported by `from <this module> import *`.
__all__ = (
    "add_messages",
    "MessagesState",
    "MessageGraph",
    "REMOVE_ALL_MESSAGES",
)

# Operand type accepted by `add_messages`: either a single message-like value
# or a list of them.
Messages = list[MessageLikeRepresentation] | MessageLikeRepresentation

# Sentinel message ID. When `add_messages` finds a `RemoveMessage` carrying
# this ID in the incoming messages, it discards everything that came before
# the marker and returns only the messages that follow it.
REMOVE_ALL_MESSAGES = "__remove_all__"
  33. def _add_messages_wrapper(func: Callable) -> Callable[[Messages, Messages], Messages]:
  34. def _add_messages(
  35. left: Messages | None = None, right: Messages | None = None, **kwargs: Any
  36. ) -> Messages | Callable[[Messages, Messages], Messages]:
  37. if left is not None and right is not None:
  38. return func(left, right, **kwargs)
  39. elif left is not None or right is not None:
  40. msg = (
  41. f"Must specify non-null arguments for both 'left' and 'right'. Only "
  42. f"received: '{'left' if left else 'right'}'."
  43. )
  44. raise ValueError(msg)
  45. else:
  46. return partial(func, **kwargs)
  47. _add_messages.__doc__ = func.__doc__
  48. return cast(Callable[[Messages, Messages], Messages], _add_messages)
  49. @_add_messages_wrapper
  50. def add_messages(
  51. left: Messages,
  52. right: Messages,
  53. *,
  54. format: Literal["langchain-openai"] | None = None,
  55. ) -> Messages:
  56. """Merges two lists of messages, updating existing messages by ID.
  57. By default, this ensures the state is "append-only", unless the
  58. new message has the same ID as an existing message.
  59. Args:
  60. left: The base list of `Messages`.
  61. right: The list of `Messages` (or single `Message`) to merge
  62. into the base list.
  63. format: The format to return messages in. If `None` then `Messages` will be
  64. returned as is. If `langchain-openai` then `Messages` will be returned as
  65. `BaseMessage` objects with their contents formatted to match OpenAI message
  66. format, meaning contents can be string, `'text'` blocks, or `'image_url'` blocks
  67. and tool responses are returned as their own `ToolMessage` objects.
  68. !!! important "Requirement"
  69. Must have `langchain-core>=0.3.11` installed to use this feature.
  70. Returns:
  71. A new list of messages with the messages from `right` merged into `left`.
  72. If a message in `right` has the same ID as a message in `left`, the
  73. message from `right` will replace the message from `left`.
  74. Example: Basic usage
  75. ```python
  76. from langchain_core.messages import AIMessage, HumanMessage
  77. msgs1 = [HumanMessage(content="Hello", id="1")]
  78. msgs2 = [AIMessage(content="Hi there!", id="2")]
  79. add_messages(msgs1, msgs2)
  80. # [HumanMessage(content='Hello', id='1'), AIMessage(content='Hi there!', id='2')]
  81. ```
  82. Example: Overwrite existing message
  83. ```python
  84. msgs1 = [HumanMessage(content="Hello", id="1")]
  85. msgs2 = [HumanMessage(content="Hello again", id="1")]
  86. add_messages(msgs1, msgs2)
  87. # [HumanMessage(content='Hello again', id='1')]
  88. ```
  89. Example: Use in a StateGraph
  90. ```python
  91. from typing import Annotated
  92. from typing_extensions import TypedDict
  93. from langgraph.graph import StateGraph
  94. class State(TypedDict):
  95. messages: Annotated[list, add_messages]
  96. builder = StateGraph(State)
  97. builder.add_node("chatbot", lambda state: {"messages": [("assistant", "Hello")]})
  98. builder.set_entry_point("chatbot")
  99. builder.set_finish_point("chatbot")
  100. graph = builder.compile()
  101. graph.invoke({})
  102. # {'messages': [AIMessage(content='Hello', id=...)]}
  103. ```
  104. Example: Use OpenAI message format
  105. ```python
  106. from typing import Annotated
  107. from typing_extensions import TypedDict
  108. from langgraph.graph import StateGraph, add_messages
  109. class State(TypedDict):
  110. messages: Annotated[list, add_messages(format="langchain-openai")]
  111. def chatbot_node(state: State) -> list:
  112. return {
  113. "messages": [
  114. {
  115. "role": "user",
  116. "content": [
  117. {
  118. "type": "text",
  119. "text": "Here's an image:",
  120. "cache_control": {"type": "ephemeral"},
  121. },
  122. {
  123. "type": "image",
  124. "source": {
  125. "type": "base64",
  126. "media_type": "image/jpeg",
  127. "data": "1234",
  128. },
  129. },
  130. ],
  131. },
  132. ]
  133. }
  134. builder = StateGraph(State)
  135. builder.add_node("chatbot", chatbot_node)
  136. builder.set_entry_point("chatbot")
  137. builder.set_finish_point("chatbot")
  138. graph = builder.compile()
  139. graph.invoke({"messages": []})
  140. # {
  141. # 'messages': [
  142. # HumanMessage(
  143. # content=[
  144. # {"type": "text", "text": "Here's an image:"},
  145. # {
  146. # "type": "image_url",
  147. # "image_url": {"url": ""},
  148. # },
  149. # ],
  150. # ),
  151. # ]
  152. # }
  153. ```
  154. """
  155. remove_all_idx = None
  156. # coerce to list
  157. if not isinstance(left, list):
  158. left = [left] # type: ignore[assignment]
  159. if not isinstance(right, list):
  160. right = [right] # type: ignore[assignment]
  161. # coerce to message
  162. left = [
  163. message_chunk_to_message(cast(BaseMessageChunk, m))
  164. for m in convert_to_messages(left)
  165. ]
  166. right = [
  167. message_chunk_to_message(cast(BaseMessageChunk, m))
  168. for m in convert_to_messages(right)
  169. ]
  170. # assign missing ids
  171. for m in left:
  172. if m.id is None:
  173. m.id = str(uuid.uuid4())
  174. for idx, m in enumerate(right):
  175. if m.id is None:
  176. m.id = str(uuid.uuid4())
  177. if isinstance(m, RemoveMessage) and m.id == REMOVE_ALL_MESSAGES:
  178. remove_all_idx = idx
  179. if remove_all_idx is not None:
  180. return right[remove_all_idx + 1 :]
  181. # merge
  182. merged = left.copy()
  183. merged_by_id = {m.id: i for i, m in enumerate(merged)}
  184. ids_to_remove = set()
  185. for m in right:
  186. if (existing_idx := merged_by_id.get(m.id)) is not None:
  187. if isinstance(m, RemoveMessage):
  188. ids_to_remove.add(m.id)
  189. else:
  190. ids_to_remove.discard(m.id)
  191. merged[existing_idx] = m
  192. else:
  193. if isinstance(m, RemoveMessage):
  194. raise ValueError(
  195. f"Attempting to delete a message with an ID that doesn't exist ('{m.id}')"
  196. )
  197. merged_by_id[m.id] = len(merged)
  198. merged.append(m)
  199. merged = [m for m in merged if m.id not in ids_to_remove]
  200. if format == "langchain-openai":
  201. merged = _format_messages(merged)
  202. elif format:
  203. msg = f"Unrecognized {format=}. Expected one of 'langchain-openai', None."
  204. raise ValueError(msg)
  205. else:
  206. pass
  207. return merged
  208. @deprecated(
  209. "MessageGraph is deprecated in langgraph 1.0.0, to be removed in 2.0.0. Please use StateGraph with a `messages` key instead.",
  210. category=None,
  211. )
  212. class MessageGraph(StateGraph):
  213. """A StateGraph where every node receives a list of messages as input and returns one or more messages as output.
  214. MessageGraph is a subclass of StateGraph whose entire state is a single, append-only* list of messages.
  215. Each node in a MessageGraph takes a list of messages as input and returns zero or more
  216. messages as output. The `add_messages` function is used to merge the output messages from each node
  217. into the existing list of messages in the graph's state.
  218. Examples:
  219. ```pycon
  220. >>> from langgraph.graph.message import MessageGraph
  221. ...
  222. >>> builder = MessageGraph()
  223. >>> builder.add_node("chatbot", lambda state: [("assistant", "Hello!")])
  224. >>> builder.set_entry_point("chatbot")
  225. >>> builder.set_finish_point("chatbot")
  226. >>> builder.compile().invoke([("user", "Hi there.")])
  227. [HumanMessage(content="Hi there.", id='...'), AIMessage(content="Hello!", id='...')]
  228. ```
  229. ```pycon
  230. >>> from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
  231. >>> from langgraph.graph.message import MessageGraph
  232. ...
  233. >>> builder = MessageGraph()
  234. >>> builder.add_node(
  235. ... "chatbot",
  236. ... lambda state: [
  237. ... AIMessage(
  238. ... content="Hello!",
  239. ... tool_calls=[{"name": "search", "id": "123", "args": {"query": "X"}}],
  240. ... )
  241. ... ],
  242. ... )
  243. >>> builder.add_node(
  244. ... "search", lambda state: [ToolMessage(content="Searching...", tool_call_id="123")]
  245. ... )
  246. >>> builder.set_entry_point("chatbot")
  247. >>> builder.add_edge("chatbot", "search")
  248. >>> builder.set_finish_point("search")
  249. >>> builder.compile().invoke([HumanMessage(content="Hi there. Can you search for X?")])
  250. {'messages': [HumanMessage(content="Hi there. Can you search for X?", id='b8b7d8f4-7f4d-4f4d-9c1d-f8b8d8f4d9c1'),
  251. AIMessage(content="Hello!", id='f4d9c1d8-8d8f-4d9c-b8b7-d8f4f4d9c1d8'),
  252. ToolMessage(content="Searching...", id='d8f4f4d9-c1d8-4f4d-b8b7-d8f4f4d9c1d8', tool_call_id="123")]}
  253. ```
  254. """
  255. def __init__(self) -> None:
  256. warnings.warn(
  257. "MessageGraph is deprecated in LangGraph v1.0.0, to be removed in v2.0.0. Please use StateGraph with a `messages` key instead.",
  258. category=LangGraphDeprecatedSinceV10,
  259. stacklevel=2,
  260. )
  261. super().__init__(Annotated[list[AnyMessage], add_messages]) # type: ignore[arg-type]
class MessagesState(TypedDict):
    """Typed dictionary with a single `messages` key annotated with `add_messages`.

    The shape matches the `State` classes shown in the `add_messages`
    docstring examples, where such a TypedDict is passed to `StateGraph`.
    """

    # List of messages; the `add_messages` annotation names the function used
    # to merge updates written to this key.
    messages: Annotated[list[AnyMessage], add_messages]
  264. def _format_messages(messages: Sequence[BaseMessage]) -> list[BaseMessage]:
  265. try:
  266. from langchain_core.messages import convert_to_openai_messages
  267. except ImportError:
  268. msg = (
  269. "Must have langchain-core>=0.3.11 installed to use automatic message "
  270. "formatting (format='langchain-openai'). Please update your langchain-core "
  271. "version or remove the 'format' flag. Returning un-formatted "
  272. "messages."
  273. )
  274. warnings.warn(msg)
  275. return list(messages)
  276. else:
  277. return convert_to_messages(convert_to_openai_messages(messages))
  278. def push_message(
  279. message: MessageLikeRepresentation | BaseMessageChunk,
  280. *,
  281. state_key: str | None = "messages",
  282. ) -> AnyMessage:
  283. """Write a message manually to the `messages` / `messages-tuple` stream mode.
  284. Will automatically write to the channel specified in the `state_key` unless `state_key` is `None`.
  285. """
  286. from langchain_core.callbacks.base import (
  287. BaseCallbackHandler,
  288. BaseCallbackManager,
  289. )
  290. from langgraph.config import get_config
  291. from langgraph.pregel._messages import StreamMessagesHandler
  292. config = get_config()
  293. message = next(x for x in convert_to_messages([message]))
  294. if message.id is None:
  295. raise ValueError("Message ID is required")
  296. if isinstance(config["callbacks"], BaseCallbackManager):
  297. manager = config["callbacks"]
  298. handlers = manager.handlers
  299. elif isinstance(config["callbacks"], list) and all(
  300. isinstance(x, BaseCallbackHandler) for x in config["callbacks"]
  301. ):
  302. handlers = config["callbacks"]
  303. if stream_handler := next(
  304. (x for x in handlers if isinstance(x, StreamMessagesHandler)), None
  305. ):
  306. metadata = config["metadata"]
  307. message_meta = (
  308. tuple(cast(str, metadata["langgraph_checkpoint_ns"]).split(NS_SEP)),
  309. metadata,
  310. )
  311. stream_handler._emit(message_meta, message, dedupe=False)
  312. if state_key:
  313. config[CONF][CONFIG_KEY_SEND]([(state_key, message)])
  314. return message