remote.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015
  1. from __future__ import annotations
  2. import logging
  3. from collections.abc import AsyncIterator, Iterator, Sequence
  4. from dataclasses import asdict
  5. from typing import (
  6. Any,
  7. Literal,
  8. cast,
  9. )
  10. from uuid import UUID
  11. import langsmith as ls
  12. from langchain_core.runnables import RunnableConfig
  13. from langchain_core.runnables.graph import (
  14. Edge as DrawableEdge,
  15. )
  16. from langchain_core.runnables.graph import (
  17. Graph as DrawableGraph,
  18. )
  19. from langchain_core.runnables.graph import (
  20. Node as DrawableNode,
  21. )
  22. from langgraph.checkpoint.base import CheckpointMetadata
  23. from langgraph_sdk.client import (
  24. LangGraphClient,
  25. SyncLangGraphClient,
  26. get_client,
  27. get_sync_client,
  28. )
  29. from langgraph_sdk.schema import (
  30. Checkpoint,
  31. QueryParamTypes,
  32. ThreadState,
  33. )
  34. from langgraph_sdk.schema import (
  35. Command as CommandSDK,
  36. )
  37. from langgraph_sdk.schema import (
  38. StreamMode as StreamModeSDK,
  39. )
  40. from typing_extensions import Self
  41. from langgraph._internal._config import merge_configs
  42. from langgraph._internal._constants import (
  43. CONF,
  44. CONFIG_KEY_CHECKPOINT_ID,
  45. CONFIG_KEY_CHECKPOINT_MAP,
  46. CONFIG_KEY_CHECKPOINT_NS,
  47. CONFIG_KEY_STREAM,
  48. CONFIG_KEY_TASK_ID,
  49. INTERRUPT,
  50. NS_SEP,
  51. )
  52. from langgraph.errors import GraphInterrupt, ParentCommand
  53. from langgraph.pregel.protocol import PregelProtocol, StreamProtocol
  54. from langgraph.types import (
  55. All,
  56. Command,
  57. Interrupt,
  58. PregelTask,
  59. StateSnapshot,
  60. StreamMode,
  61. )
logger = logging.getLogger(__name__)

# Public API of this module.
__all__ = ("RemoteGraph", "RemoteException")

# Configurable keys stripped by `RemoteGraph._sanitize_config` before a config
# is sent to the server: they describe the *local* checkpoint/task context and
# are not meaningful to the remote deployment.
_CONF_DROPLIST = frozenset(
    (
        CONFIG_KEY_CHECKPOINT_MAP,
        CONFIG_KEY_CHECKPOINT_ID,
        CONFIG_KEY_CHECKPOINT_NS,
        CONFIG_KEY_TASK_ID,
    ),
)
  72. def _sanitize_config_value(v: Any) -> Any:
  73. """Recursively sanitize a config value to ensure it contains only primitives."""
  74. if isinstance(v, (str, int, float, bool, UUID)):
  75. return v
  76. elif isinstance(v, dict):
  77. sanitized_dict = {}
  78. for k, val in v.items():
  79. if isinstance(k, str):
  80. sanitized_value = _sanitize_config_value(val)
  81. if sanitized_value is not None:
  82. sanitized_dict[k] = sanitized_value
  83. return sanitized_dict
  84. elif isinstance(v, (list, tuple)):
  85. sanitized_list = []
  86. for item in v:
  87. sanitized_item = _sanitize_config_value(item)
  88. if sanitized_item is not None:
  89. sanitized_list.append(sanitized_item)
  90. return sanitized_list
  91. return None
  92. class RemoteException(Exception):
  93. """Exception raised when an error occurs in the remote graph."""
  94. pass
  95. class RemoteGraph(PregelProtocol):
  96. """The `RemoteGraph` class is a client implementation for calling remote
  97. APIs that implement the LangGraph Server API specification.
  98. For example, the `RemoteGraph` class can be used to call APIs from deployments
  99. on LangSmith Deployment.
  100. `RemoteGraph` behaves the same way as a `Graph` and can be used directly as
  101. a node in another `Graph`.
  102. """
  103. assistant_id: str
  104. name: str | None
  105. def __init__(
  106. self,
  107. assistant_id: str, # graph_id
  108. /,
  109. *,
  110. url: str | None = None,
  111. api_key: str | None = None,
  112. headers: dict[str, str] | None = None,
  113. client: LangGraphClient | None = None,
  114. sync_client: SyncLangGraphClient | None = None,
  115. config: RunnableConfig | None = None,
  116. name: str | None = None,
  117. distributed_tracing: bool = False,
  118. ):
  119. """Specify `url`, `api_key`, and/or `headers` to create default sync and async clients.
  120. If `client` or `sync_client` are provided, they will be used instead of the default clients.
  121. See `LangGraphClient` and `SyncLangGraphClient` for details on the default clients. At least
  122. one of `url`, `client`, or `sync_client` must be provided.
  123. Args:
  124. assistant_id: The assistant ID or graph name of the remote graph to use.
  125. url: The URL of the remote API.
  126. api_key: The API key to use for authentication. If not provided, it will be read from the environment (`LANGGRAPH_API_KEY`, `LANGSMITH_API_KEY`, or `LANGCHAIN_API_KEY`).
  127. headers: Additional headers to include in the requests.
  128. client: A `LangGraphClient` instance to use instead of creating a default client.
  129. sync_client: A `SyncLangGraphClient` instance to use instead of creating a default client.
  130. config: An optional `RunnableConfig` instance with additional configuration.
  131. name: Human-readable name to attach to the RemoteGraph instance.
  132. This is useful for adding `RemoteGraph` as a subgraph via `graph.add_node(remote_graph)`.
  133. If not provided, defaults to the assistant ID.
  134. distributed_tracing: Whether to enable sending LangSmith distributed tracing headers.
  135. """
  136. self.assistant_id = assistant_id
  137. if name is None:
  138. self.name = assistant_id
  139. else:
  140. self.name = name
  141. self.config = config
  142. self.distributed_tracing = distributed_tracing
  143. if client is None and url is not None:
  144. client = get_client(url=url, api_key=api_key, headers=headers)
  145. self.client = client
  146. if sync_client is None and url is not None:
  147. sync_client = get_sync_client(url=url, api_key=api_key, headers=headers)
  148. self.sync_client = sync_client
  149. def _validate_client(self) -> LangGraphClient:
  150. if self.client is None:
  151. raise ValueError(
  152. "Async client is not initialized: please provide `url` or `client` when initializing `RemoteGraph`."
  153. )
  154. return self.client
  155. def _validate_sync_client(self) -> SyncLangGraphClient:
  156. if self.sync_client is None:
  157. raise ValueError(
  158. "Sync client is not initialized: please provide `url` or `sync_client` when initializing `RemoteGraph`."
  159. )
  160. return self.sync_client
  161. def copy(self, update: dict[str, Any]) -> Self:
  162. attrs = {**self.__dict__, **update}
  163. return self.__class__(attrs.pop("assistant_id"), **attrs)
  164. def with_config(self, config: RunnableConfig | None = None, **kwargs: Any) -> Self:
  165. return self.copy(
  166. {"config": merge_configs(self.config, config, cast(RunnableConfig, kwargs))}
  167. )
  168. def _get_drawable_nodes(
  169. self, graph: dict[str, list[dict[str, Any]]]
  170. ) -> dict[str, DrawableNode]:
  171. nodes = {}
  172. for node in graph["nodes"]:
  173. node_id = str(node["id"])
  174. node_data = node.get("data", {})
  175. # Get node name from node_data if available. If not, use node_id.
  176. node_name = node.get("name")
  177. if node_name is None:
  178. if isinstance(node_data, dict):
  179. node_name = node_data.get("name", node_id)
  180. else:
  181. node_name = node_id
  182. nodes[node_id] = DrawableNode(
  183. id=node_id,
  184. name=node_name,
  185. data=node_data,
  186. metadata=node.get("metadata"),
  187. )
  188. return nodes
  189. def get_graph(
  190. self,
  191. config: RunnableConfig | None = None,
  192. *,
  193. xray: int | bool = False,
  194. headers: dict[str, str] | None = None,
  195. params: QueryParamTypes | None = None,
  196. ) -> DrawableGraph:
  197. """Get graph by graph name.
  198. This method calls `GET /assistants/{assistant_id}/graph`.
  199. Args:
  200. config: This parameter is not used.
  201. xray: Include graph representation of subgraphs. If an integer
  202. value is provided, only subgraphs with a depth less than or
  203. equal to the value will be included.
  204. Returns:
  205. The graph information for the assistant in JSON format.
  206. """
  207. sync_client = self._validate_sync_client()
  208. graph = sync_client.assistants.get_graph(
  209. assistant_id=self.assistant_id,
  210. xray=xray,
  211. headers=headers,
  212. params=params,
  213. )
  214. return DrawableGraph(
  215. nodes=self._get_drawable_nodes(graph),
  216. edges=[DrawableEdge(**edge) for edge in graph["edges"]],
  217. )
  218. async def aget_graph(
  219. self,
  220. config: RunnableConfig | None = None,
  221. *,
  222. xray: int | bool = False,
  223. headers: dict[str, str] | None = None,
  224. params: QueryParamTypes | None = None,
  225. ) -> DrawableGraph:
  226. """Get graph by graph name.
  227. This method calls `GET /assistants/{assistant_id}/graph`.
  228. Args:
  229. config: This parameter is not used.
  230. xray: Include graph representation of subgraphs. If an integer
  231. value is provided, only subgraphs with a depth less than or
  232. equal to the value will be included.
  233. Returns:
  234. The graph information for the assistant in JSON format.
  235. """
  236. client = self._validate_client()
  237. graph = await client.assistants.get_graph(
  238. assistant_id=self.assistant_id,
  239. xray=xray,
  240. headers=headers,
  241. params=params,
  242. )
  243. return DrawableGraph(
  244. nodes=self._get_drawable_nodes(graph),
  245. edges=[DrawableEdge(**edge) for edge in graph["edges"]],
  246. )
    def _create_state_snapshot(self, state: ThreadState) -> StateSnapshot:
        """Convert a `ThreadState` payload from the server into a `StateSnapshot`.

        Nested task states are converted recursively; tasks that only carry a
        checkpoint reference get a config pointing at that checkpoint instead.
        """
        tasks: list[PregelTask] = []
        for task in state["tasks"]:
            interrupts = tuple(
                Interrupt(**interrupt) for interrupt in task["interrupts"]
            )
            tasks.append(
                PregelTask(
                    id=task["id"],
                    name=task["name"],
                    path=tuple(),
                    # Error strings from the server are wrapped in a plain Exception.
                    error=Exception(task["error"]) if task["error"] else None,
                    interrupts=interrupts,
                    state=(
                        # Recurse into a full nested snapshot when present,
                        # otherwise fall back to a checkpoint-addressing config.
                        self._create_state_snapshot(task["state"])
                        if task["state"]
                        else (
                            cast(RunnableConfig, {"configurable": task["checkpoint"]})
                            if task["checkpoint"]
                            else None
                        )
                    ),
                    result=task.get("result"),
                )
            )
        return StateSnapshot(
            values=state["values"],
            next=tuple(state["next"]) if state["next"] else tuple(),
            config={
                "configurable": {
                    "thread_id": state["checkpoint"]["thread_id"],
                    "checkpoint_ns": state["checkpoint"]["checkpoint_ns"],
                    "checkpoint_id": state["checkpoint"]["checkpoint_id"],
                    "checkpoint_map": state["checkpoint"].get("checkpoint_map", {}),
                }
            },
            metadata=CheckpointMetadata(**state["metadata"]),
            created_at=state["created_at"],
            parent_config=(
                {
                    "configurable": {
                        "thread_id": state["parent_checkpoint"]["thread_id"],
                        "checkpoint_ns": state["parent_checkpoint"]["checkpoint_ns"],
                        "checkpoint_id": state["parent_checkpoint"]["checkpoint_id"],
                        "checkpoint_map": state["parent_checkpoint"].get(
                            "checkpoint_map", {}
                        ),
                    }
                }
                if state["parent_checkpoint"]
                else None
            ),
            tasks=tuple(tasks),
            # Flatten every task's interrupts into the snapshot-level view.
            interrupts=tuple([i for task in tasks for i in task.interrupts]),
        )
  302. def _get_checkpoint(self, config: RunnableConfig | None) -> Checkpoint | None:
  303. if config is None:
  304. return None
  305. checkpoint = {}
  306. if "thread_id" in config["configurable"]:
  307. checkpoint["thread_id"] = config["configurable"]["thread_id"]
  308. if "checkpoint_ns" in config["configurable"]:
  309. checkpoint["checkpoint_ns"] = config["configurable"]["checkpoint_ns"]
  310. if "checkpoint_id" in config["configurable"]:
  311. checkpoint["checkpoint_id"] = config["configurable"]["checkpoint_id"]
  312. if "checkpoint_map" in config["configurable"]:
  313. checkpoint["checkpoint_map"] = config["configurable"]["checkpoint_map"]
  314. return checkpoint if checkpoint else None
  315. def _get_config(self, checkpoint: Checkpoint) -> RunnableConfig:
  316. return {
  317. "configurable": {
  318. "thread_id": checkpoint["thread_id"],
  319. "checkpoint_ns": checkpoint["checkpoint_ns"],
  320. "checkpoint_id": checkpoint["checkpoint_id"],
  321. "checkpoint_map": checkpoint.get("checkpoint_map", {}),
  322. }
  323. }
  324. def _sanitize_config(self, config: RunnableConfig) -> RunnableConfig:
  325. """Sanitize the config to remove non-serializable fields."""
  326. sanitized: RunnableConfig = {}
  327. if "recursion_limit" in config:
  328. sanitized["recursion_limit"] = config["recursion_limit"]
  329. if "tags" in config:
  330. sanitized["tags"] = [tag for tag in config["tags"] if isinstance(tag, str)]
  331. if "metadata" in config:
  332. sanitized["metadata"] = {}
  333. for k, v in config["metadata"].items():
  334. if (
  335. isinstance(k, str)
  336. and (sanitized_value := _sanitize_config_value(v)) is not None
  337. ):
  338. sanitized["metadata"][k] = sanitized_value
  339. if "configurable" in config:
  340. sanitized["configurable"] = {}
  341. for k, v in config["configurable"].items():
  342. if (
  343. isinstance(k, str)
  344. and k not in _CONF_DROPLIST
  345. and (sanitized_value := _sanitize_config_value(v)) is not None
  346. ):
  347. sanitized["configurable"][k] = sanitized_value
  348. return sanitized
  349. def get_state(
  350. self,
  351. config: RunnableConfig,
  352. *,
  353. subgraphs: bool = False,
  354. headers: dict[str, str] | None = None,
  355. params: QueryParamTypes | None = None,
  356. ) -> StateSnapshot:
  357. """Get the state of a thread.
  358. This method calls `POST /threads/{thread_id}/state/checkpoint` if a
  359. checkpoint is specified in the config or `GET /threads/{thread_id}/state`
  360. if no checkpoint is specified.
  361. Args:
  362. config: A `RunnableConfig` that includes `thread_id` in the
  363. `configurable` field.
  364. subgraphs: Include subgraphs in the state.
  365. headers: Optional custom headers to include with the request.
  366. params: Optional query parameters to include with the request.
  367. Returns:
  368. The latest state of the thread.
  369. """
  370. sync_client = self._validate_sync_client()
  371. merged_config = merge_configs(self.config, config)
  372. state = sync_client.threads.get_state(
  373. thread_id=merged_config["configurable"]["thread_id"],
  374. checkpoint=self._get_checkpoint(merged_config),
  375. subgraphs=subgraphs,
  376. headers=headers,
  377. params=params,
  378. )
  379. return self._create_state_snapshot(state)
  380. async def aget_state(
  381. self,
  382. config: RunnableConfig,
  383. *,
  384. subgraphs: bool = False,
  385. headers: dict[str, str] | None = None,
  386. params: QueryParamTypes | None = None,
  387. ) -> StateSnapshot:
  388. """Get the state of a thread.
  389. This method calls `POST /threads/{thread_id}/state/checkpoint` if a
  390. checkpoint is specified in the config or `GET /threads/{thread_id}/state`
  391. if no checkpoint is specified.
  392. Args:
  393. config: A `RunnableConfig` that includes `thread_id` in the
  394. `configurable` field.
  395. subgraphs: Include subgraphs in the state.
  396. headers: Optional custom headers to include with the request.
  397. params: Optional query parameters to include with the request.
  398. Returns:
  399. The latest state of the thread.
  400. """
  401. client = self._validate_client()
  402. merged_config = merge_configs(self.config, config)
  403. state = await client.threads.get_state(
  404. thread_id=merged_config["configurable"]["thread_id"],
  405. checkpoint=self._get_checkpoint(merged_config),
  406. subgraphs=subgraphs,
  407. headers=headers,
  408. params=params,
  409. )
  410. return self._create_state_snapshot(state)
  411. def get_state_history(
  412. self,
  413. config: RunnableConfig,
  414. *,
  415. filter: dict[str, Any] | None = None,
  416. before: RunnableConfig | None = None,
  417. limit: int | None = None,
  418. headers: dict[str, str] | None = None,
  419. params: QueryParamTypes | None = None,
  420. ) -> Iterator[StateSnapshot]:
  421. """Get the state history of a thread.
  422. This method calls `POST /threads/{thread_id}/history`.
  423. Args:
  424. config: A `RunnableConfig` that includes `thread_id` in the
  425. `configurable` field.
  426. filter: Metadata to filter on.
  427. before: A `RunnableConfig` that includes checkpoint metadata.
  428. limit: Max number of states to return.
  429. Returns:
  430. States of the thread.
  431. """
  432. sync_client = self._validate_sync_client()
  433. merged_config = merge_configs(self.config, config)
  434. states = sync_client.threads.get_history(
  435. thread_id=merged_config["configurable"]["thread_id"],
  436. limit=limit if limit else 10,
  437. before=self._get_checkpoint(before),
  438. metadata=filter,
  439. checkpoint=self._get_checkpoint(merged_config),
  440. headers=headers,
  441. params=params,
  442. )
  443. for state in states:
  444. yield self._create_state_snapshot(state)
  445. async def aget_state_history(
  446. self,
  447. config: RunnableConfig,
  448. *,
  449. filter: dict[str, Any] | None = None,
  450. before: RunnableConfig | None = None,
  451. limit: int | None = None,
  452. headers: dict[str, str] | None = None,
  453. params: QueryParamTypes | None = None,
  454. ) -> AsyncIterator[StateSnapshot]:
  455. """Get the state history of a thread.
  456. This method calls `POST /threads/{thread_id}/history`.
  457. Args:
  458. config: A `RunnableConfig` that includes `thread_id` in the
  459. `configurable` field.
  460. filter: Metadata to filter on.
  461. before: A `RunnableConfig` that includes checkpoint metadata.
  462. limit: Max number of states to return.
  463. headers: Optional custom headers to include with the request.
  464. params: Optional query parameters to include with the request.
  465. Returns:
  466. States of the thread.
  467. """
  468. client = self._validate_client()
  469. merged_config = merge_configs(self.config, config)
  470. states = await client.threads.get_history(
  471. thread_id=merged_config["configurable"]["thread_id"],
  472. limit=limit if limit else 10,
  473. before=self._get_checkpoint(before),
  474. metadata=filter,
  475. checkpoint=self._get_checkpoint(merged_config),
  476. headers=headers,
  477. params=params,
  478. )
  479. for state in states:
  480. yield self._create_state_snapshot(state)
    def bulk_update_state(
        self,
        config: RunnableConfig,
        updates: list[tuple[dict[str, Any] | None, str | None]],
    ) -> RunnableConfig:
        """Bulk-apply state updates to a thread.

        Not supported by `RemoteGraph`; use `update_state` per update instead.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError
    async def abulk_update_state(
        self,
        config: RunnableConfig,
        updates: list[tuple[dict[str, Any] | None, str | None]],
    ) -> RunnableConfig:
        """Bulk-apply state updates to a thread (async).

        Not supported by `RemoteGraph`; use `aupdate_state` per update instead.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError
  493. def update_state(
  494. self,
  495. config: RunnableConfig,
  496. values: dict[str, Any] | Any | None,
  497. as_node: str | None = None,
  498. *,
  499. headers: dict[str, str] | None = None,
  500. params: QueryParamTypes | None = None,
  501. ) -> RunnableConfig:
  502. """Update the state of a thread.
  503. This method calls `POST /threads/{thread_id}/state`.
  504. Args:
  505. config: A `RunnableConfig` that includes `thread_id` in the
  506. `configurable` field.
  507. values: Values to update to the state.
  508. as_node: Update the state as if this node had just executed.
  509. Returns:
  510. `RunnableConfig` for the updated thread.
  511. """
  512. sync_client = self._validate_sync_client()
  513. merged_config = merge_configs(self.config, config)
  514. response: dict = sync_client.threads.update_state( # type: ignore
  515. thread_id=merged_config["configurable"]["thread_id"],
  516. values=values,
  517. as_node=as_node,
  518. checkpoint=self._get_checkpoint(merged_config),
  519. headers=headers,
  520. params=params,
  521. )
  522. return self._get_config(response["checkpoint"])
  523. async def aupdate_state(
  524. self,
  525. config: RunnableConfig,
  526. values: dict[str, Any] | Any | None,
  527. as_node: str | None = None,
  528. *,
  529. headers: dict[str, str] | None = None,
  530. params: QueryParamTypes | None = None,
  531. ) -> RunnableConfig:
  532. """Update the state of a thread.
  533. This method calls `POST /threads/{thread_id}/state`.
  534. Args:
  535. config: A `RunnableConfig` that includes `thread_id` in the
  536. `configurable` field.
  537. values: Values to update to the state.
  538. as_node: Update the state as if this node had just executed.
  539. Returns:
  540. `RunnableConfig` for the updated thread.
  541. """
  542. client = self._validate_client()
  543. merged_config = merge_configs(self.config, config)
  544. response: dict = await client.threads.update_state( # type: ignore
  545. thread_id=merged_config["configurable"]["thread_id"],
  546. values=values,
  547. as_node=as_node,
  548. checkpoint=self._get_checkpoint(merged_config),
  549. headers=headers,
  550. params=params,
  551. )
  552. return self._get_config(response["checkpoint"])
    def _get_stream_modes(
        self,
        stream_mode: StreamMode | list[StreamMode] | None,
        config: RunnableConfig | None,
        default: StreamMode = "updates",
    ) -> tuple[list[StreamModeSDK], list[StreamModeSDK], bool, StreamProtocol | None]:
        """Compute the stream modes to request from the remote graph.

        Returns a 4-tuple of:
          * the final list of stream modes sent to the remote graph,
          * the modes the caller originally requested (used for filtering),
          * whether the caller passed a single mode (so chunks are yielded
            bare, without the mode label),
          * the parent graph's `StreamProtocol`, if one is present in config.

        'updates' mode is always included in the outgoing list so that
        interrupts can be detected in the remote graph.
        """
        updated_stream_modes: list[StreamModeSDK] = []
        req_single = True
        # coerce to list, or add default stream mode
        if stream_mode:
            if isinstance(stream_mode, str):
                updated_stream_modes.append(stream_mode)
            else:
                # a list was passed, so the caller expects labeled chunks back
                req_single = False
                updated_stream_modes.extend(stream_mode)
        else:
            updated_stream_modes.append(default)
        # snapshot the caller's request before augmenting the outgoing list
        requested_stream_modes = updated_stream_modes.copy()
        # add any modes needed by a parent graph streaming through this one
        stream: StreamProtocol | None = (
            (config or {}).get(CONF, {}).get(CONFIG_KEY_STREAM)
        )
        if stream:
            updated_stream_modes.extend(stream.modes)
        # map "messages" to "messages-tuple" (the SDK's wire-level name)
        if "messages" in updated_stream_modes:
            updated_stream_modes.remove("messages")
            updated_stream_modes.append("messages-tuple")
        # if requested "messages-tuple",
        # map to "messages" in requested_stream_modes
        if "messages-tuple" in requested_stream_modes:
            requested_stream_modes.remove("messages-tuple")
            requested_stream_modes.append("messages")
        # add 'updates' mode if not present
        if "updates" not in updated_stream_modes:
            updated_stream_modes.append("updates")
        # remove 'events', as it's not supported in Pregel
        if "events" in updated_stream_modes:
            updated_stream_modes.remove("events")
        return (updated_stream_modes, requested_stream_modes, req_single, stream)
    def stream(
        self,
        input: dict[str, Any] | Any,
        config: RunnableConfig | None = None,
        *,
        stream_mode: StreamMode | list[StreamMode] | None = None,
        interrupt_before: All | Sequence[str] | None = None,
        interrupt_after: All | Sequence[str] | None = None,
        subgraphs: bool = False,
        headers: dict[str, str] | None = None,
        params: QueryParamTypes | None = None,
        **kwargs: Any,
    ) -> Iterator[dict[str, Any] | Any]:
        """Create a run and stream the results.

        This method calls `POST /threads/{thread_id}/runs/stream` if a `thread_id`
        is specified in the `configurable` field of the config or
        `POST /runs/stream` otherwise.

        Args:
            input: Input to the graph.
            config: A `RunnableConfig` for graph invocation.
            stream_mode: Stream mode(s) to use.
            interrupt_before: Interrupt the graph before these nodes.
            interrupt_after: Interrupt the graph after these nodes.
            subgraphs: Stream from subgraphs.
            headers: Additional headers to pass to the request.
            params: Optional query parameters to include with the request.
            **kwargs: Additional params to pass to client.runs.stream.

        Yields:
            The output of the graph.
        """
        sync_client = self._validate_sync_client()
        merged_config = merge_configs(self.config, config)
        sanitized_config = self._sanitize_config(merged_config)
        stream_modes, requested, req_single, stream = self._get_stream_modes(
            stream_mode, config
        )
        # A Command input is serialized and sent via the `command` field.
        if isinstance(input, Command):
            command: CommandSDK | None = cast(CommandSDK, asdict(input))
            input = None
        else:
            command = None
        # thread_id is passed as its own argument, not inside the config.
        thread_id = sanitized_config.get("configurable", {}).pop("thread_id", None)
        for chunk in sync_client.runs.stream(
            thread_id=thread_id,
            assistant_id=self.assistant_id,
            input=input,
            command=command,
            config=sanitized_config,
            stream_mode=stream_modes,
            interrupt_before=interrupt_before,
            interrupt_after=interrupt_after,
            stream_subgraphs=subgraphs or stream is not None,
            if_not_exists="create",
            headers=(
                _merge_tracing_headers(headers) if self.distributed_tracing else headers
            ),
            params=params,
            **kwargs,
        ):
            # split mode and ns from the event name, e.g. "updates|a|b"
            if NS_SEP in chunk.event:
                mode, ns_ = chunk.event.split(NS_SEP, 1)
                ns = tuple(ns_.split(NS_SEP))
            else:
                mode, ns = chunk.event, ()
            # raise ParentCommand exception for command events
            if mode == "command" and chunk.data.get("graph") == Command.PARENT:
                raise ParentCommand(Command(**chunk.data))
            # prepend caller ns (as it is not passed to remote graph)
            if caller_ns := (config or {}).get(CONF, {}).get(CONFIG_KEY_CHECKPOINT_NS):
                caller_ns = tuple(caller_ns.split(NS_SEP))
                ns = caller_ns + ns
            # forward to parent stream if it requested this mode
            if stream is not None and mode in stream.modes:
                stream((ns, mode, chunk.data))
            # raise interrupt (only when running as a subgraph) or errors
            if chunk.event.startswith("updates"):
                if isinstance(chunk.data, dict) and INTERRUPT in chunk.data:
                    if caller_ns:
                        raise GraphInterrupt(
                            [Interrupt(**i) for i in chunk.data[INTERRUPT]]
                        )
            elif chunk.event.startswith("error"):
                raise RemoteException(chunk.data)
            # filter for what was actually requested
            if mode not in requested:
                continue
            # messages payloads are normalized to tuples
            if chunk.event.startswith("messages"):
                chunk = chunk._replace(data=tuple(chunk.data))  # type: ignore
            # emit chunk, shaped by whether subgraphs/multiple modes were requested
            if subgraphs:
                if NS_SEP in chunk.event:
                    mode, ns_ = chunk.event.split(NS_SEP, 1)
                    ns = tuple(ns_.split(NS_SEP))
                else:
                    mode, ns = chunk.event, ()
                if req_single:
                    yield ns, chunk.data
                else:
                    yield ns, mode, chunk.data
            elif req_single:
                yield chunk.data
            else:
                yield chunk
    async def astream(
        self,
        input: dict[str, Any] | Any,
        config: RunnableConfig | None = None,
        *,
        stream_mode: StreamMode | list[StreamMode] | None = None,
        interrupt_before: All | Sequence[str] | None = None,
        interrupt_after: All | Sequence[str] | None = None,
        subgraphs: bool = False,
        headers: dict[str, str] | None = None,
        params: QueryParamTypes | None = None,
        **kwargs: Any,
    ) -> AsyncIterator[dict[str, Any] | Any]:
        """Create a run and stream the results (async).

        This method calls `POST /threads/{thread_id}/runs/stream` if a `thread_id`
        is specified in the `configurable` field of the config or
        `POST /runs/stream` otherwise.

        Args:
            input: Input to the graph.
            config: A `RunnableConfig` for graph invocation.
            stream_mode: Stream mode(s) to use.
            interrupt_before: Interrupt the graph before these nodes.
            interrupt_after: Interrupt the graph after these nodes.
            subgraphs: Stream from subgraphs.
            headers: Additional headers to pass to the request.
            params: Optional query parameters to include with the request.
            **kwargs: Additional params to pass to client.runs.stream.

        Yields:
            The output of the graph.
        """
        client = self._validate_client()
        merged_config = merge_configs(self.config, config)
        sanitized_config = self._sanitize_config(merged_config)
        stream_modes, requested, req_single, stream = self._get_stream_modes(
            stream_mode, config
        )
        # A Command input is serialized and sent via the `command` field.
        if isinstance(input, Command):
            command: CommandSDK | None = cast(CommandSDK, asdict(input))
            input = None
        else:
            command = None
        # thread_id is passed as its own argument, not inside the config.
        thread_id = sanitized_config.get("configurable", {}).pop("thread_id", None)
        async for chunk in client.runs.stream(
            thread_id=thread_id,
            assistant_id=self.assistant_id,
            input=input,
            command=command,
            config=sanitized_config,
            stream_mode=stream_modes,
            interrupt_before=interrupt_before,
            interrupt_after=interrupt_after,
            stream_subgraphs=subgraphs or stream is not None,
            if_not_exists="create",
            headers=(
                _merge_tracing_headers(headers) if self.distributed_tracing else headers
            ),
            params=params,
            **kwargs,
        ):
            # split mode and ns from the event name, e.g. "updates|a|b"
            if NS_SEP in chunk.event:
                mode, ns_ = chunk.event.split(NS_SEP, 1)
                ns = tuple(ns_.split(NS_SEP))
            else:
                mode, ns = chunk.event, ()
            # raise ParentCommand exception for command events
            if mode == "command" and chunk.data.get("graph") == Command.PARENT:
                raise ParentCommand(Command(**chunk.data))
            # prepend caller ns (as it is not passed to remote graph)
            if caller_ns := (config or {}).get(CONF, {}).get(CONFIG_KEY_CHECKPOINT_NS):
                caller_ns = tuple(caller_ns.split(NS_SEP))
                ns = caller_ns + ns
            # forward to parent stream if it requested this mode
            if stream is not None and mode in stream.modes:
                stream((ns, mode, chunk.data))
            # raise interrupt (only when running as a subgraph) or errors
            if chunk.event.startswith("updates"):
                if isinstance(chunk.data, dict) and INTERRUPT in chunk.data:
                    if caller_ns:
                        raise GraphInterrupt(
                            [Interrupt(**i) for i in chunk.data[INTERRUPT]]
                        )
            elif chunk.event.startswith("error"):
                raise RemoteException(chunk.data)
            # filter for what was actually requested
            if mode not in requested:
                continue
            # messages payloads are normalized to tuples
            if chunk.event.startswith("messages"):
                chunk = chunk._replace(data=tuple(chunk.data))  # type: ignore
            # emit chunk, shaped by whether subgraphs/multiple modes were requested
            if subgraphs:
                if NS_SEP in chunk.event:
                    mode, ns_ = chunk.event.split(NS_SEP, 1)
                    ns = tuple(ns_.split(NS_SEP))
                else:
                    mode, ns = chunk.event, ()
                if req_single:
                    yield ns, chunk.data
                else:
                    yield ns, mode, chunk.data
            elif req_single:
                yield chunk.data
            else:
                yield chunk
  805. async def astream_events(
  806. self,
  807. input: Any,
  808. config: RunnableConfig | None = None,
  809. *,
  810. version: Literal["v1", "v2"],
  811. include_names: Sequence[All] | None = None,
  812. include_types: Sequence[All] | None = None,
  813. include_tags: Sequence[All] | None = None,
  814. exclude_names: Sequence[All] | None = None,
  815. exclude_types: Sequence[All] | None = None,
  816. exclude_tags: Sequence[All] | None = None,
  817. **kwargs: Any,
  818. ) -> AsyncIterator[dict[str, Any]]:
  819. raise NotImplementedError
  820. def invoke(
  821. self,
  822. input: dict[str, Any] | Any,
  823. config: RunnableConfig | None = None,
  824. *,
  825. interrupt_before: All | Sequence[str] | None = None,
  826. interrupt_after: All | Sequence[str] | None = None,
  827. headers: dict[str, str] | None = None,
  828. params: QueryParamTypes | None = None,
  829. **kwargs: Any,
  830. ) -> dict[str, Any] | Any:
  831. """Create a run, wait until it finishes and return the final state.
  832. Args:
  833. input: Input to the graph.
  834. config: A `RunnableConfig` for graph invocation.
  835. interrupt_before: Interrupt the graph before these nodes.
  836. interrupt_after: Interrupt the graph after these nodes.
  837. headers: Additional headers to pass to the request.
  838. **kwargs: Additional params to pass to RemoteGraph.stream.
  839. Returns:
  840. The output of the graph.
  841. """
  842. for chunk in self.stream(
  843. input,
  844. config=config,
  845. interrupt_before=interrupt_before,
  846. interrupt_after=interrupt_after,
  847. headers=headers,
  848. stream_mode="values",
  849. params=params,
  850. **kwargs,
  851. ):
  852. pass
  853. try:
  854. return chunk
  855. except UnboundLocalError:
  856. logger.warning("No events received from remote graph")
  857. return None
  858. async def ainvoke(
  859. self,
  860. input: dict[str, Any] | Any,
  861. config: RunnableConfig | None = None,
  862. *,
  863. interrupt_before: All | Sequence[str] | None = None,
  864. interrupt_after: All | Sequence[str] | None = None,
  865. headers: dict[str, str] | None = None,
  866. params: QueryParamTypes | None = None,
  867. **kwargs: Any,
  868. ) -> dict[str, Any] | Any:
  869. """Create a run, wait until it finishes and return the final state.
  870. Args:
  871. input: Input to the graph.
  872. config: A `RunnableConfig` for graph invocation.
  873. interrupt_before: Interrupt the graph before these nodes.
  874. interrupt_after: Interrupt the graph after these nodes.
  875. headers: Additional headers to pass to the request.
  876. **kwargs: Additional params to pass to RemoteGraph.astream.
  877. Returns:
  878. The output of the graph.
  879. """
  880. async for chunk in self.astream(
  881. input,
  882. config=config,
  883. interrupt_before=interrupt_before,
  884. interrupt_after=interrupt_after,
  885. headers=headers,
  886. stream_mode="values",
  887. params=params,
  888. **kwargs,
  889. ):
  890. pass
  891. try:
  892. return chunk
  893. except UnboundLocalError:
  894. logger.warning("No events received from remote graph")
  895. return None
  896. def _merge_tracing_headers(headers: dict[str, str] | None) -> dict[str, str] | None:
  897. if rt := ls.get_current_run_tree():
  898. tracing_headers = rt.to_headers()
  899. if headers:
  900. if "baggage" in headers:
  901. tracing_headers["baggage"] = (
  902. f"{headers['baggage']},{tracing_headers['baggage']}"
  903. )
  904. headers.update(tracing_headers)
  905. else:
  906. headers = tracing_headers
  907. return headers