_read.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. from __future__ import annotations
  2. from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence
  3. from functools import cached_property
  4. from typing import (
  5. Any,
  6. )
  7. from langchain_core.runnables import Runnable, RunnableConfig
  8. from langgraph._internal._config import merge_configs
  9. from langgraph._internal._constants import CONF, CONFIG_KEY_READ
  10. from langgraph._internal._runnable import RunnableCallable, RunnableSeq
  11. from langgraph.pregel._utils import find_subgraph_pregel
  12. from langgraph.pregel._write import ChannelWrite
  13. from langgraph.pregel.protocol import PregelProtocol
  14. from langgraph.types import CachePolicy, RetryPolicy
# Signature of the read function injected into config under CONFIG_KEY_READ:
# called as read(select, fresh) -> a single channel value when `select` is a
# str, or a dict of channel name -> value when `select` is a sequence.
READ_TYPE = Callable[[str | Sequence[str], bool], Any | dict[str, Any]]
# Cache key for a node's input: (mapper callable, tuple of channel names).
INPUT_CACHE_KEY_TYPE = tuple[Callable[..., Any], tuple[str, ...]]
  17. class ChannelRead(RunnableCallable):
  18. """Implements the logic for reading state from CONFIG_KEY_READ.
  19. Usable both as a runnable as well as a static method to call imperatively."""
  20. channel: str | list[str]
  21. fresh: bool = False
  22. mapper: Callable[[Any], Any] | None = None
  23. def __init__(
  24. self,
  25. channel: str | list[str],
  26. *,
  27. fresh: bool = False,
  28. mapper: Callable[[Any], Any] | None = None,
  29. tags: list[str] | None = None,
  30. ) -> None:
  31. super().__init__(
  32. func=self._read,
  33. afunc=self._aread,
  34. tags=tags,
  35. name=None,
  36. trace=False,
  37. )
  38. self.fresh = fresh
  39. self.mapper = mapper
  40. self.channel = channel
  41. def get_name(self, suffix: str | None = None, *, name: str | None = None) -> str:
  42. if name:
  43. pass
  44. elif isinstance(self.channel, str):
  45. name = f"ChannelRead<{self.channel}>"
  46. else:
  47. name = f"ChannelRead<{','.join(self.channel)}>"
  48. return super().get_name(suffix, name=name)
  49. def _read(self, _: Any, config: RunnableConfig) -> Any:
  50. return self.do_read(
  51. config, select=self.channel, fresh=self.fresh, mapper=self.mapper
  52. )
  53. async def _aread(self, _: Any, config: RunnableConfig) -> Any:
  54. return self.do_read(
  55. config, select=self.channel, fresh=self.fresh, mapper=self.mapper
  56. )
  57. @staticmethod
  58. def do_read(
  59. config: RunnableConfig,
  60. *,
  61. select: str | list[str],
  62. fresh: bool = False,
  63. mapper: Callable[[Any], Any] | None = None,
  64. ) -> Any:
  65. try:
  66. read: READ_TYPE = config[CONF][CONFIG_KEY_READ]
  67. except KeyError:
  68. raise RuntimeError(
  69. "Not configured with a read function"
  70. "Make sure to call in the context of a Pregel process"
  71. )
  72. if mapper:
  73. return mapper(read(select, fresh))
  74. else:
  75. return read(select, fresh)
  76. DEFAULT_BOUND = RunnableCallable(lambda input: input)
  77. class PregelNode:
  78. """A node in a Pregel graph. This won't be invoked as a runnable by the graph
  79. itself, but instead acts as a container for the components necessary to make
  80. a PregelExecutableTask for a node."""
  81. channels: str | list[str]
  82. """The channels that will be passed as input to `bound`.
  83. If a str, the node will be invoked with its value if it isn't empty.
  84. If a list, the node will be invoked with a dict of those channels' values."""
  85. triggers: list[str]
  86. """If any of these channels is written to, this node will be triggered in
  87. the next step."""
  88. mapper: Callable[[Any], Any] | None
  89. """A function to transform the input before passing it to `bound`."""
  90. writers: list[Runnable]
  91. """A list of writers that will be executed after `bound`, responsible for
  92. taking the output of `bound` and writing it to the appropriate channels."""
  93. bound: Runnable[Any, Any]
  94. """The main logic of the node. This will be invoked with the input from
  95. `channels`."""
  96. retry_policy: Sequence[RetryPolicy] | None
  97. """The retry policies to use when invoking the node."""
  98. cache_policy: CachePolicy | None
  99. """The cache policy to use when invoking the node."""
  100. tags: Sequence[str] | None
  101. """Tags to attach to the node for tracing."""
  102. metadata: Mapping[str, Any] | None
  103. """Metadata to attach to the node for tracing."""
  104. subgraphs: Sequence[PregelProtocol]
  105. """Subgraphs used by the node."""
  106. def __init__(
  107. self,
  108. *,
  109. channels: str | list[str],
  110. triggers: Sequence[str],
  111. mapper: Callable[[Any], Any] | None = None,
  112. writers: list[Runnable] | None = None,
  113. tags: list[str] | None = None,
  114. metadata: Mapping[str, Any] | None = None,
  115. bound: Runnable[Any, Any] | None = None,
  116. retry_policy: RetryPolicy | Sequence[RetryPolicy] | None = None,
  117. cache_policy: CachePolicy | None = None,
  118. subgraphs: Sequence[PregelProtocol] | None = None,
  119. ) -> None:
  120. self.channels = channels
  121. self.triggers = list(triggers)
  122. self.mapper = mapper
  123. self.writers = writers or []
  124. self.bound = bound if bound is not None else DEFAULT_BOUND
  125. self.cache_policy = cache_policy
  126. if isinstance(retry_policy, RetryPolicy):
  127. self.retry_policy = (retry_policy,)
  128. else:
  129. self.retry_policy = retry_policy
  130. self.tags = tags
  131. self.metadata = metadata
  132. if subgraphs is not None:
  133. self.subgraphs = subgraphs
  134. elif self.bound is not DEFAULT_BOUND:
  135. try:
  136. subgraph = find_subgraph_pregel(self.bound)
  137. except Exception:
  138. subgraph = None
  139. if subgraph:
  140. self.subgraphs = [subgraph]
  141. else:
  142. self.subgraphs = []
  143. else:
  144. self.subgraphs = []
  145. def copy(self, update: dict[str, Any]) -> PregelNode:
  146. attrs = {**self.__dict__, **update}
  147. # Drop the cached properties
  148. attrs.pop("flat_writers", None)
  149. attrs.pop("node", None)
  150. attrs.pop("input_cache_key", None)
  151. return PregelNode(**attrs)
  152. @cached_property
  153. def flat_writers(self) -> list[Runnable]:
  154. """Get writers with optimizations applied. Dedupes consecutive ChannelWrites."""
  155. writers = self.writers.copy()
  156. while (
  157. len(writers) > 1
  158. and isinstance(writers[-1], ChannelWrite)
  159. and isinstance(writers[-2], ChannelWrite)
  160. ):
  161. # we can combine writes if they are consecutive
  162. # careful to not modify the original writers list or ChannelWrite
  163. writers[-2] = ChannelWrite(
  164. writes=writers[-2].writes + writers[-1].writes,
  165. )
  166. writers.pop()
  167. return writers
  168. @cached_property
  169. def node(self) -> Runnable[Any, Any] | None:
  170. """Get a runnable that combines `bound` and `writers`."""
  171. writers = self.flat_writers
  172. if self.bound is DEFAULT_BOUND and not writers:
  173. return None
  174. elif self.bound is DEFAULT_BOUND and len(writers) == 1:
  175. return writers[0]
  176. elif self.bound is DEFAULT_BOUND:
  177. return RunnableSeq(*writers)
  178. elif writers:
  179. return RunnableSeq(self.bound, *writers)
  180. else:
  181. return self.bound
  182. @cached_property
  183. def input_cache_key(self) -> INPUT_CACHE_KEY_TYPE:
  184. """Get a cache key for the input to the node.
  185. This is used to avoid calculating the same input multiple times."""
  186. return (
  187. self.mapper,
  188. tuple(self.channels)
  189. if isinstance(self.channels, list)
  190. else (self.channels,),
  191. )
  192. def invoke(
  193. self,
  194. input: Any,
  195. config: RunnableConfig | None = None,
  196. **kwargs: Any | None,
  197. ) -> Any:
  198. self_config: RunnableConfig = {"metadata": self.metadata, "tags": self.tags}
  199. return self.bound.invoke(
  200. input,
  201. merge_configs(self_config, config),
  202. **kwargs,
  203. )
  204. async def ainvoke(
  205. self,
  206. input: Any,
  207. config: RunnableConfig | None = None,
  208. **kwargs: Any | None,
  209. ) -> Any:
  210. self_config: RunnableConfig = {"metadata": self.metadata, "tags": self.tags}
  211. return await self.bound.ainvoke(
  212. input,
  213. merge_configs(self_config, config),
  214. **kwargs,
  215. )
  216. def stream(
  217. self,
  218. input: Any,
  219. config: RunnableConfig | None = None,
  220. **kwargs: Any | None,
  221. ) -> Iterator[Any]:
  222. self_config: RunnableConfig = {"metadata": self.metadata, "tags": self.tags}
  223. yield from self.bound.stream(
  224. input,
  225. merge_configs(self_config, config),
  226. **kwargs,
  227. )
  228. async def astream(
  229. self,
  230. input: Any,
  231. config: RunnableConfig | None = None,
  232. **kwargs: Any | None,
  233. ) -> AsyncIterator[Any]:
  234. self_config: RunnableConfig = {"metadata": self.metadata, "tags": self.tags}
  235. async for item in self.bound.astream(
  236. input,
  237. merge_configs(self_config, config),
  238. **kwargs,
  239. ):
  240. yield item