protocol.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. from __future__ import annotations
  2. from abc import abstractmethod
  3. from collections.abc import AsyncIterator, Callable, Iterator, Sequence
  4. from typing import Any, Generic, cast
  5. from langchain_core.runnables import Runnable, RunnableConfig
  6. from langchain_core.runnables.graph import Graph as DrawableGraph
  7. from typing_extensions import Self
  8. from langgraph.types import All, Command, StateSnapshot, StateUpdate, StreamMode
  9. from langgraph.typing import ContextT, InputT, OutputT, StateT
  10. __all__ = ("PregelProtocol", "StreamProtocol")
  11. class PregelProtocol(Runnable[InputT, Any], Generic[StateT, ContextT, InputT, OutputT]):
  12. @abstractmethod
  13. def with_config(
  14. self, config: RunnableConfig | None = None, **kwargs: Any
  15. ) -> Self: ...
  16. @abstractmethod
  17. def get_graph(
  18. self,
  19. config: RunnableConfig | None = None,
  20. *,
  21. xray: int | bool = False,
  22. ) -> DrawableGraph: ...
  23. @abstractmethod
  24. async def aget_graph(
  25. self,
  26. config: RunnableConfig | None = None,
  27. *,
  28. xray: int | bool = False,
  29. ) -> DrawableGraph: ...
  30. @abstractmethod
  31. def get_state(
  32. self, config: RunnableConfig, *, subgraphs: bool = False
  33. ) -> StateSnapshot: ...
  34. @abstractmethod
  35. async def aget_state(
  36. self, config: RunnableConfig, *, subgraphs: bool = False
  37. ) -> StateSnapshot: ...
  38. @abstractmethod
  39. def get_state_history(
  40. self,
  41. config: RunnableConfig,
  42. *,
  43. filter: dict[str, Any] | None = None,
  44. before: RunnableConfig | None = None,
  45. limit: int | None = None,
  46. ) -> Iterator[StateSnapshot]: ...
  47. @abstractmethod
  48. def aget_state_history(
  49. self,
  50. config: RunnableConfig,
  51. *,
  52. filter: dict[str, Any] | None = None,
  53. before: RunnableConfig | None = None,
  54. limit: int | None = None,
  55. ) -> AsyncIterator[StateSnapshot]: ...
  56. @abstractmethod
  57. def bulk_update_state(
  58. self,
  59. config: RunnableConfig,
  60. updates: Sequence[Sequence[StateUpdate]],
  61. ) -> RunnableConfig: ...
  62. @abstractmethod
  63. async def abulk_update_state(
  64. self,
  65. config: RunnableConfig,
  66. updates: Sequence[Sequence[StateUpdate]],
  67. ) -> RunnableConfig: ...
  68. @abstractmethod
  69. def update_state(
  70. self,
  71. config: RunnableConfig,
  72. values: dict[str, Any] | Any | None,
  73. as_node: str | None = None,
  74. ) -> RunnableConfig: ...
  75. @abstractmethod
  76. async def aupdate_state(
  77. self,
  78. config: RunnableConfig,
  79. values: dict[str, Any] | Any | None,
  80. as_node: str | None = None,
  81. ) -> RunnableConfig: ...
  82. @abstractmethod
  83. def stream(
  84. self,
  85. input: InputT | Command | None,
  86. config: RunnableConfig | None = None,
  87. *,
  88. context: ContextT | None = None,
  89. stream_mode: StreamMode | list[StreamMode] | None = None,
  90. interrupt_before: All | Sequence[str] | None = None,
  91. interrupt_after: All | Sequence[str] | None = None,
  92. subgraphs: bool = False,
  93. ) -> Iterator[dict[str, Any] | Any]: ...
  94. @abstractmethod
  95. def astream(
  96. self,
  97. input: InputT | Command | None,
  98. config: RunnableConfig | None = None,
  99. *,
  100. context: ContextT | None = None,
  101. stream_mode: StreamMode | list[StreamMode] | None = None,
  102. interrupt_before: All | Sequence[str] | None = None,
  103. interrupt_after: All | Sequence[str] | None = None,
  104. subgraphs: bool = False,
  105. ) -> AsyncIterator[dict[str, Any] | Any]: ...
  106. @abstractmethod
  107. def invoke(
  108. self,
  109. input: InputT | Command | None,
  110. config: RunnableConfig | None = None,
  111. *,
  112. context: ContextT | None = None,
  113. interrupt_before: All | Sequence[str] | None = None,
  114. interrupt_after: All | Sequence[str] | None = None,
  115. ) -> dict[str, Any] | Any: ...
  116. @abstractmethod
  117. async def ainvoke(
  118. self,
  119. input: InputT | Command | None,
  120. config: RunnableConfig | None = None,
  121. *,
  122. context: ContextT | None = None,
  123. interrupt_before: All | Sequence[str] | None = None,
  124. interrupt_after: All | Sequence[str] | None = None,
  125. ) -> dict[str, Any] | Any: ...
  126. StreamChunk = tuple[tuple[str, ...], str, Any]
  127. class StreamProtocol:
  128. __slots__ = ("modes", "__call__")
  129. modes: set[StreamMode]
  130. __call__: Callable[[Self, StreamChunk], None]
  131. def __init__(
  132. self,
  133. __call__: Callable[[StreamChunk], None],
  134. modes: set[StreamMode],
  135. ) -> None:
  136. self.__call__ = cast(Callable[[Self, StreamChunk], None], __call__)
  137. self.modes = modes