runtime.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. from __future__ import annotations
  2. from dataclasses import dataclass, field, replace
  3. from typing import Any, Generic, cast
  4. from langgraph.store.base import BaseStore
  5. from typing_extensions import TypedDict, Unpack
  6. from langgraph._internal._constants import CONF, CONFIG_KEY_RUNTIME
  7. from langgraph.config import get_config
  8. from langgraph.types import _DC_KWARGS, StreamWriter
  9. from langgraph.typing import ContextT
# Public names of this module (what `from ... import *` exports).
__all__ = ("Runtime", "get_runtime")
  11. def _no_op_stream_writer(_: Any) -> None: ...
class _RuntimeOverrides(TypedDict, Generic[ContextT], total=False):
    """Keyword arguments accepted by `Runtime.override`.

    Every key is optional (`total=False`). Each key names the `Runtime` field
    of the same name and carries that field's type.
    """

    context: ContextT
    store: BaseStore | None
    stream_writer: StreamWriter
    previous: Any
  17. @dataclass(**_DC_KWARGS)
  18. class Runtime(Generic[ContextT]):
  19. """Convenience class that bundles run-scoped context and other runtime utilities.
  20. !!! version-added "Added in version v0.6.0"
  21. Example:
  22. ```python
  23. from typing import TypedDict
  24. from langgraph.graph import StateGraph
  25. from dataclasses import dataclass
  26. from langgraph.runtime import Runtime
  27. from langgraph.store.memory import InMemoryStore
  28. @dataclass
  29. class Context: # (1)!
  30. user_id: str
  31. class State(TypedDict, total=False):
  32. response: str
  33. store = InMemoryStore() # (2)!
  34. store.put(("users",), "user_123", {"name": "Alice"})
  35. def personalized_greeting(state: State, runtime: Runtime[Context]) -> State:
  36. '''Generate personalized greeting using runtime context and store.'''
  37. user_id = runtime.context.user_id # (3)!
  38. name = "unknown_user"
  39. if runtime.store:
  40. if memory := runtime.store.get(("users",), user_id):
  41. name = memory.value["name"]
  42. response = f"Hello {name}! Nice to see you again."
  43. return {"response": response}
  44. graph = (
  45. StateGraph(state_schema=State, context_schema=Context)
  46. .add_node("personalized_greeting", personalized_greeting)
  47. .set_entry_point("personalized_greeting")
  48. .set_finish_point("personalized_greeting")
  49. .compile(store=store)
  50. )
  51. result = graph.invoke({}, context=Context(user_id="user_123"))
  52. print(result)
  53. # > {'response': 'Hello Alice! Nice to see you again.'}
  54. ```
  55. 1. Define a schema for the runtime context.
  56. 2. Create a store to persist memories and other information.
  57. 3. Use the runtime context to access the `user_id`.
  58. """
  59. context: ContextT = field(default=None) # type: ignore[assignment]
  60. """Static context for the graph run, like `user_id`, `db_conn`, etc.
  61. Can also be thought of as 'run dependencies'."""
  62. store: BaseStore | None = field(default=None)
  63. """Store for the graph run, enabling persistence and memory."""
  64. stream_writer: StreamWriter = field(default=_no_op_stream_writer)
  65. """Function that writes to the custom stream."""
  66. previous: Any = field(default=None)
  67. """The previous return value for the given thread.
  68. Only available with the functional API when a checkpointer is provided.
  69. """
  70. def merge(self, other: Runtime[ContextT]) -> Runtime[ContextT]:
  71. """Merge two runtimes together.
  72. If a value is not provided in the other runtime, the value from the current runtime is used.
  73. """
  74. return Runtime(
  75. context=other.context or self.context,
  76. store=other.store or self.store,
  77. stream_writer=other.stream_writer
  78. if other.stream_writer is not _no_op_stream_writer
  79. else self.stream_writer,
  80. previous=self.previous if other.previous is None else other.previous,
  81. )
  82. def override(
  83. self, **overrides: Unpack[_RuntimeOverrides[ContextT]]
  84. ) -> Runtime[ContextT]:
  85. """Replace the runtime with a new runtime with the given overrides."""
  86. return replace(self, **overrides)
  87. DEFAULT_RUNTIME = Runtime(
  88. context=None,
  89. store=None,
  90. stream_writer=_no_op_stream_writer,
  91. previous=None,
  92. )
  93. def get_runtime(context_schema: type[ContextT] | None = None) -> Runtime[ContextT]:
  94. """Get the runtime for the current graph run.
  95. Args:
  96. context_schema: Optional schema used for type hinting the return type of the runtime.
  97. Returns:
  98. The runtime for the current graph run.
  99. """
  100. # TODO: in an ideal world, we would have a context manager for
  101. # the runtime that's independent of the config. this will follow
  102. # from the removal of the configurable packing
  103. runtime = cast(Runtime[ContextT], get_config()[CONF].get(CONFIG_KEY_RUNTIME))
  104. return runtime