from __future__ import annotations

from collections.abc import AsyncIterator, Callable, Iterator, Sequence
from typing import (
    Any,
    TypeVar,
    cast,
)
from uuid import UUID, uuid4

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, LLMResult

from langgraph._internal._constants import NS_SEP
from langgraph.constants import TAG_HIDDEN, TAG_NOSTREAM
from langgraph.pregel.protocol import StreamChunk
from langgraph.types import Command

try:
    from langchain_core.tracers._streaming import _StreamingCallbackHandler
except ImportError:
    _StreamingCallbackHandler = object  # type: ignore

T = TypeVar("T")
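
# A Meta pairs a run's checkpoint namespace path with the metadata dict for
# that run; both halves accompany every message emitted in stream_mode=messages.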
Meta = tuple[tuple[str, ...], dict[str, Any]]


class StreamMessagesHandler(BaseCallbackHandler, _StreamingCallbackHandler):
    """A callback handler that implements stream_mode=messages.

    Collects messages from:

    (1) chat model stream events; and
    (2) node outputs.
    """

    run_inline = True
    """We want this callback to run in the main thread to avoid order/locking issues."""

    def __init__(
        self,
        stream: Callable[[StreamChunk], None],
        subgraphs: bool,
        *,
        parent_ns: tuple[str, ...] | None = None,
    ) -> None:
        """Configure the handler to stream messages from LLMs and nodes.

        Args:
            stream: A callable that takes a StreamChunk and emits it.
            subgraphs: Whether to emit messages from subgraphs.
            parent_ns: The namespace where the handler was created.
                We keep track of this namespace to allow calls to subgraphs that
                were explicitly requested as a stream with `messages` mode
                configured.

                For example, parent_ns is what lets events flow through when a
                subgraph is explicitly streamed with `stream_mode="messages"`
                inside a parent node:

                ```python
                async def parent_graph_node(state):
                    # This node is in the parent graph.
                    async for event in some_subgraph.astream(..., stream_mode="messages"):
                        handle(event)  # <-- these events will be emitted
                    return ...

                parent_graph.stream(..., stream_mode="messages", subgraphs=False)
                ```
        """
        self.stream = stream
        self.subgraphs = subgraphs
        self.metadata: dict[UUID, Meta] = {}
        self.seen: set[int | str] = set()
        self.parent_ns = parent_ns
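
    # Assign an id to id-less messages so dedupe has a stable key, record the
    # id in `seen`, then push (namespace, "messages", (message, metadata)) into
    # the stream.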
    def _emit(self, meta: Meta, message: BaseMessage, *, dedupe: bool = False) -> None:
        if dedupe and message.id in self.seen:
            return
        if message.id is None:
            message.id = str(uuid4())
        self.seen.add(message.id)
        self.stream((meta[0], "messages", (message, meta[1])))
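
    # Best-effort scan of a node's output for messages: handles a bare message,
    # a sequence of messages, a state dict whose values are messages (or
    # sequences of them), and, via dir(), arbitrary objects such as dataclass
    # or Pydantic state instances.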
    def _find_and_emit_messages(self, meta: Meta, response: Any) -> None:
        if isinstance(response, BaseMessage):
            self._emit(meta, response, dedupe=True)
        elif isinstance(response, Sequence):
            for value in response:
                if isinstance(value, BaseMessage):
                    self._emit(meta, value, dedupe=True)
        elif isinstance(response, dict):
            for value in response.values():
                if isinstance(value, BaseMessage):
                    self._emit(meta, value, dedupe=True)
                elif isinstance(value, Sequence):
                    for item in value:
                        if isinstance(item, BaseMessage):
                            self._emit(meta, item, dedupe=True)
        elif hasattr(response, "__dir__") and callable(response.__dir__):
            for key in dir(response):
                try:
                    value = getattr(response, key)
                    if isinstance(value, BaseMessage):
                        self._emit(meta, value, dedupe=True)
                    elif isinstance(value, Sequence):
                        for item in value:
                            if isinstance(item, BaseMessage):
                                self._emit(meta, item, dedupe=True)
                except AttributeError:
                    pass
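
    # The `_StreamingCallbackHandler` interface expects these taps; this
    # handler collects messages from the callback methods below instead, so
    # both taps pass the output through untouched.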
    def tap_output_aiter(
        self, run_id: UUID, output: AsyncIterator[T]
    ) -> AsyncIterator[T]:
        return output

    def tap_output_iter(self, run_id: UUID, output: Iterator[T]) -> Iterator[T]:
        return output
    def on_chat_model_start(
        self,
        serialized: dict[str, Any],
        messages: list[list[BaseMessage]],
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        tags: list[str] | None = None,
        metadata: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> Any:
        if metadata and (not tags or (TAG_NOSTREAM not in tags)):
            ns = tuple(cast(str, metadata["langgraph_checkpoint_ns"]).split(NS_SEP))[
                :-1
            ]
            if not self.subgraphs and len(ns) > 0 and ns != self.parent_ns:
                return
            if tags:
                if filtered_tags := [t for t in tags if not t.startswith("seq:step")]:
                    metadata["tags"] = filtered_tags
            self.metadata[run_id] = (ns, metadata)
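
    # Emit each streamed chunk for runs registered in on_chat_model_start. No
    # dedupe here: all chunks of one message share the same id, and each chunk
    # should reach the consumer.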
    def on_llm_new_token(
        self,
        token: str,
        *,
        chunk: ChatGenerationChunk | None = None,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        tags: list[str] | None = None,
        **kwargs: Any,
    ) -> Any:
        if not isinstance(chunk, ChatGenerationChunk):
            return
        if meta := self.metadata.get(run_id):
            self._emit(meta, chunk.message)
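
    # Emit the final message once the run completes; dedupe=True skips it when
    # its chunks were already streamed token by token.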
    def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        **kwargs: Any,
    ) -> Any:
        if meta := self.metadata.get(run_id):
            if response.generations and response.generations[0]:
                gen = response.generations[0][0]
                if isinstance(gen, ChatGeneration):
                    self._emit(meta, gen.message, dedupe=True)
        self.metadata.pop(run_id, None)

    def on_llm_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        **kwargs: Any,
    ) -> Any:
        self.metadata.pop(run_id, None)
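
    # Register a node run (a chain whose name matches langgraph_node) and mark
    # any messages already present in its inputs as seen, so unchanged inputs
    # are not re-emitted from the node's output in on_chain_end.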
    def on_chain_start(
        self,
        serialized: dict[str, Any],
        inputs: dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        tags: list[str] | None = None,
        metadata: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> Any:
        if (
            metadata
            and kwargs.get("name") == metadata.get("langgraph_node")
            and (not tags or TAG_HIDDEN not in tags)
        ):
            ns = tuple(cast(str, metadata["langgraph_checkpoint_ns"]).split(NS_SEP))[
                :-1
            ]
            if not self.subgraphs and len(ns) > 0:
                return
            self.metadata[run_id] = (ns, metadata)
            if isinstance(inputs, dict):
                for value in inputs.values():
                    if isinstance(value, BaseMessage):
                        if value.id is not None:
                            self.seen.add(value.id)
                    elif isinstance(value, Sequence) and not isinstance(value, str):
                        for item in value:
                            if isinstance(item, BaseMessage):
                                if item.id is not None:
                                    self.seen.add(item.id)
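
    # Emit messages found in the node's output, unwrapping Command.update
    # (including Commands nested inside a sequence) before scanning.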
    def on_chain_end(
        self,
        response: Any,
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        **kwargs: Any,
    ) -> Any:
        if meta := self.metadata.pop(run_id, None):
            # Handle Command node updates
            if isinstance(response, Command):
                self._find_and_emit_messages(meta, response.update)
            # Handle a list of Command updates
            elif isinstance(response, Sequence) and any(
                isinstance(value, Command) for value in response
            ):
                for value in response:
                    if isinstance(value, Command):
                        self._find_and_emit_messages(meta, value.update)
                    else:
                        self._find_and_emit_messages(meta, value)
            # Handle basic updates / streaming
            else:
                self._find_and_emit_messages(meta, response)

    def on_chain_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        **kwargs: Any,
    ) -> Any:
        self.metadata.pop(run_id, None)