_node.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. from __future__ import annotations
  2. from collections.abc import Sequence
  3. from dataclasses import dataclass
  4. from typing import Any, Generic, Protocol, TypeAlias
  5. from langchain_core.runnables import Runnable, RunnableConfig
  6. from langgraph.store.base import BaseStore
  7. from langgraph._internal._typing import EMPTY_SEQ
  8. from langgraph.runtime import Runtime
  9. from langgraph.types import CachePolicy, RetryPolicy, StreamWriter
  10. from langgraph.typing import ContextT, NodeInputT, NodeInputT_contra
  11. class _Node(Protocol[NodeInputT_contra]):
  12. def __call__(self, state: NodeInputT_contra) -> Any: ...
  13. class _NodeWithConfig(Protocol[NodeInputT_contra]):
  14. def __call__(self, state: NodeInputT_contra, config: RunnableConfig) -> Any: ...
  15. class _NodeWithWriter(Protocol[NodeInputT_contra]):
  16. def __call__(self, state: NodeInputT_contra, *, writer: StreamWriter) -> Any: ...
  17. class _NodeWithStore(Protocol[NodeInputT_contra]):
  18. def __call__(self, state: NodeInputT_contra, *, store: BaseStore) -> Any: ...
  19. class _NodeWithWriterStore(Protocol[NodeInputT_contra]):
  20. def __call__(
  21. self, state: NodeInputT_contra, *, writer: StreamWriter, store: BaseStore
  22. ) -> Any: ...
  23. class _NodeWithConfigWriter(Protocol[NodeInputT_contra]):
  24. def __call__(
  25. self, state: NodeInputT_contra, *, config: RunnableConfig, writer: StreamWriter
  26. ) -> Any: ...
  27. class _NodeWithConfigStore(Protocol[NodeInputT_contra]):
  28. def __call__(
  29. self, state: NodeInputT_contra, *, config: RunnableConfig, store: BaseStore
  30. ) -> Any: ...
  31. class _NodeWithConfigWriterStore(Protocol[NodeInputT_contra]):
  32. def __call__(
  33. self,
  34. state: NodeInputT_contra,
  35. *,
  36. config: RunnableConfig,
  37. writer: StreamWriter,
  38. store: BaseStore,
  39. ) -> Any: ...
  40. class _NodeWithRuntime(Protocol[NodeInputT_contra, ContextT]):
  41. def __call__(
  42. self, state: NodeInputT_contra, *, runtime: Runtime[ContextT]
  43. ) -> Any: ...
  44. # TODO: we probably don't want to explicitly support the config / store signatures once
  45. # we move to adding a context arg. Maybe what we do is we add support for kwargs with param spec
  46. # this is purely for typing purposes though, so can easily change in the coming weeks.
  47. StateNode: TypeAlias = (
  48. _Node[NodeInputT]
  49. | _NodeWithConfig[NodeInputT]
  50. | _NodeWithWriter[NodeInputT]
  51. | _NodeWithStore[NodeInputT]
  52. | _NodeWithWriterStore[NodeInputT]
  53. | _NodeWithConfigWriter[NodeInputT]
  54. | _NodeWithConfigStore[NodeInputT]
  55. | _NodeWithConfigWriterStore[NodeInputT]
  56. | _NodeWithRuntime[NodeInputT, ContextT]
  57. | Runnable[NodeInputT, Any]
  58. )
  59. @dataclass(slots=True)
  60. class StateNodeSpec(Generic[NodeInputT, ContextT]):
  61. runnable: StateNode[NodeInputT, ContextT]
  62. metadata: dict[str, Any] | None
  63. input_schema: type[NodeInputT]
  64. retry_policy: RetryPolicy | Sequence[RetryPolicy] | None
  65. cache_policy: CachePolicy | None
  66. ends: tuple[str, ...] | dict[str, str] | None = EMPTY_SEQ
  67. defer: bool = False