tool_node.py 69 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836
"""Tool execution node for LangGraph workflows.

This module provides prebuilt functionality for executing tools in LangGraph.
Tools are functions that models can call to interact with external systems,
APIs, databases, or perform computations.

The module implements design patterns for:

- Parallel execution of multiple tool calls for efficiency
- Robust error handling with customizable error messages
- State injection for tools that need access to graph state
- Store injection for tools that need persistent storage
- Command-based state updates for advanced control flow

Key Components:

- `ToolNode`: Main class for executing tools in LangGraph workflows
- `InjectedState`: Annotation for injecting graph state into tools
- `InjectedStore`: Annotation for injecting persistent store into tools
- `ToolRuntime`: Runtime information for tools, bundling together `state`,
  `context`, `config`, `stream_writer`, `tool_call_id`, and `store`
- `tools_condition`: Utility function for conditional routing based on tool calls

Typical Usage:

```python
from langchain_core.tools import tool
from langchain.tools import ToolNode


@tool
def my_tool(x: int) -> str:
    \"\"\"Return a formatted result for x.\"\"\"
    return f"Result: {x}"


tool_node = ToolNode([my_tool])
```
"""
  28. from __future__ import annotations
  29. import asyncio
  30. import inspect
  31. import json
  32. from collections.abc import Awaitable, Callable
  33. from copy import copy, deepcopy
  34. from dataclasses import dataclass, replace
  35. from types import UnionType
  36. from typing import (
  37. TYPE_CHECKING,
  38. Annotated,
  39. Any,
  40. Generic,
  41. Literal,
  42. TypedDict,
  43. Union,
  44. cast,
  45. get_args,
  46. get_origin,
  47. get_type_hints,
  48. )
  49. from langchain_core.messages import (
  50. AIMessage,
  51. AnyMessage,
  52. RemoveMessage,
  53. ToolCall,
  54. ToolMessage,
  55. convert_to_messages,
  56. )
  57. from langchain_core.runnables.config import (
  58. RunnableConfig,
  59. get_config_list,
  60. get_executor_for_config,
  61. )
  62. from langchain_core.tools import BaseTool, InjectedToolArg
  63. from langchain_core.tools import tool as create_tool
  64. from langchain_core.tools.base import (
  65. TOOL_MESSAGE_BLOCK_TYPES,
  66. ToolException,
  67. _DirectlyInjectedToolArg,
  68. get_all_basemodel_annotations,
  69. )
  70. from langgraph._internal._runnable import RunnableCallable
  71. from langgraph.errors import GraphBubbleUp
  72. from langgraph.graph.message import REMOVE_ALL_MESSAGES
  73. from langgraph.store.base import BaseStore # noqa: TC002
  74. from langgraph.types import Command, Send, StreamWriter
  75. from pydantic import BaseModel, ValidationError
  76. from typing_extensions import TypeVar, Unpack
  77. if TYPE_CHECKING:
  78. from collections.abc import Sequence
  79. from langgraph.runtime import Runtime
  80. from pydantic_core import ErrorDetails
# The default state type is a plain `dict`. It could become `AgentState`, but
# that depends on whether this code lives in LangChain or LangGraph; ideally the
# default would be a type with a typed `messages` key.
StateT = TypeVar("StateT", default=dict)
# Type variable for the context type; defaults to `None`.
# NOTE(review): presumably parameterizes `ToolRuntime` (defined outside this view) — confirm.
ContextT = TypeVar("ContextT", default=None)
  86. INVALID_TOOL_NAME_ERROR_TEMPLATE = (
  87. "Error: {requested_tool} is not a valid tool, try one of [{available_tools}]."
  88. )
  89. TOOL_CALL_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
  90. TOOL_EXECUTION_ERROR_TEMPLATE = (
  91. "Error executing tool '{tool_name}' with kwargs {tool_kwargs} with error:\n"
  92. " {error}\n"
  93. " Please fix the error and try again."
  94. )
  95. TOOL_INVOCATION_ERROR_TEMPLATE = (
  96. "Error invoking tool '{tool_name}' with kwargs {tool_kwargs} with error:\n"
  97. " {error}\n"
  98. " Please fix the error and try again."
  99. )
  100. class _ToolCallRequestOverrides(TypedDict, total=False):
  101. """Possible overrides for ToolCallRequest.override() method."""
  102. tool_call: ToolCall
  103. @dataclass
  104. class ToolCallRequest:
  105. """Tool execution request passed to tool call interceptors.
  106. Attributes:
  107. tool_call: Tool call dict with name, args, and id from model output.
  108. tool: BaseTool instance to be invoked, or None if tool is not
  109. registered with the `ToolNode`. When tool is `None`, interceptors can
  110. handle the request without validation. If the interceptor calls `execute()`,
  111. validation will occur and raise an error for unregistered tools.
  112. state: Agent state (`dict`, `list`, or `BaseModel`).
  113. runtime: LangGraph runtime context (optional, `None` if outside graph).
  114. """
  115. tool_call: ToolCall
  116. tool: BaseTool | None
  117. state: Any
  118. runtime: ToolRuntime
  119. def __setattr__(self, name: str, value: Any) -> None:
  120. """Raise deprecation warning when setting attributes directly.
  121. Direct attribute assignment is deprecated. Use the `override()` method instead.
  122. """
  123. import warnings
  124. # Allow setting attributes during initialization
  125. if not hasattr(self, "__dataclass_fields__") or not hasattr(self, name):
  126. object.__setattr__(self, name, value)
  127. else:
  128. warnings.warn(
  129. f"Setting attribute '{name}' on ToolCallRequest is deprecated. "
  130. "Use the override() method instead to create a new instance with modified values.",
  131. DeprecationWarning,
  132. stacklevel=2,
  133. )
  134. object.__setattr__(self, name, value)
  135. def override(
  136. self, **overrides: Unpack[_ToolCallRequestOverrides]
  137. ) -> ToolCallRequest:
  138. """Replace the request with a new request with the given overrides.
  139. Returns a new `ToolCallRequest` instance with the specified attributes replaced.
  140. This follows an immutable pattern, leaving the original request unchanged.
  141. Args:
  142. **overrides: Keyword arguments for attributes to override. Supported keys:
  143. - tool_call: Tool call dict with name, args, and id
  144. Returns:
  145. New ToolCallRequest instance with specified overrides applied.
  146. Examples:
  147. ```python
  148. # Modify tool call arguments without mutating original
  149. modified_call = {**request.tool_call, "args": {"value": 10}}
  150. new_request = request.override(tool_call=modified_call)
  151. # Override multiple attributes
  152. new_request = request.override(tool_call=modified_call, state=new_state)
  153. ```
  154. """
  155. return replace(self, **overrides)
ToolCallWrapper = Callable[
    [ToolCallRequest, Callable[[ToolCallRequest], ToolMessage | Command]],
    ToolMessage | Command,
]
"""Wrapper for tool call execution with multi-call support.

Wrapper receives:
    request: ToolCallRequest with tool_call, tool, state, and runtime.
    execute: Callable to execute the tool (CAN BE CALLED MULTIPLE TIMES).

Returns:
    ToolMessage or Command (the final result).

The execute callable can be invoked multiple times for retry logic,
with potentially modified requests each time. Each call to execute
is independent and stateless.

!!! note
    When implementing middleware for `create_agent`, use
    `AgentMiddleware.wrap_tool_call` which provides properly typed
    state parameter for better type safety.

Examples:
    Passthrough (execute once):

    ```python
    def handler(request, execute):
        return execute(request)
    ```

    Modify request before execution:

    ```python
    def handler(request, execute):
        modified_call = {
            **request.tool_call,
            "args": {
                **request.tool_call["args"],
                "value": request.tool_call["args"]["value"] * 2,
            },
        }
        modified_request = request.override(tool_call=modified_call)
        return execute(modified_request)
    ```

    Retry on error (execute multiple times):

    ```python
    def handler(request, execute):
        for attempt in range(3):
            try:
                result = execute(request)
                if is_valid(result):
                    return result
            except Exception:
                if attempt == 2:
                    raise
        return result
    ```

    Conditional retry based on response:

    ```python
    def handler(request, execute):
        for attempt in range(3):
            result = execute(request)
            if isinstance(result, ToolMessage) and result.status != "error":
                return result
            if attempt < 2:
                continue
            return result
    ```

    Cache/short-circuit without calling execute:

    ```python
    def handler(request, execute):
        if cached := get_cache(request):
            return ToolMessage(content=cached, tool_call_id=request.tool_call["id"])
        result = execute(request)
        save_cache(request, result)
        return result
    ```
"""

AsyncToolCallWrapper = Callable[
    [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
    Awaitable[ToolMessage | Command],
]
"""Async wrapper for tool call execution with multi-call support.

Async counterpart of `ToolCallWrapper`: the wrapper receives the
`ToolCallRequest` and an async `execute` callable whose result must be
awaited, and it returns an awaitable resolving to the final `ToolMessage`
or `Command`. As with the sync wrapper, `execute` may be awaited several
times (e.g. for retries) or not at all (e.g. for cached results).
"""
  223. class ToolCallWithContext(TypedDict):
  224. """ToolCall with additional context for graph state.
  225. This is an internal data structure meant to help the `ToolNode` accept
  226. tool calls with additional context (e.g. state) when dispatched using the
  227. Send API.
  228. The Send API is used in create_agent to distribute tool calls in parallel
  229. and support human-in-the-loop workflows where graph execution may be paused
  230. for an indefinite time.
  231. """
  232. tool_call: ToolCall
  233. __type: Literal["tool_call_with_context"]
  234. """Type to parameterize the payload.
  235. Using "__" as a prefix to be defensive against potential name collisions with
  236. regular user state.
  237. """
  238. state: Any
  239. """The state is provided as additional context."""
  240. def msg_content_output(output: Any) -> str | list[dict]:
  241. """Convert tool output to `ToolMessage` content format.
  242. Handles `str`, `list[dict]` (content blocks), and arbitrary objects by attempting
  243. JSON serialization with fallback to str().
  244. Args:
  245. output: Tool execution output of any type.
  246. Returns:
  247. String or list of content blocks suitable for `ToolMessage.content`.
  248. """
  249. if isinstance(output, str) or (
  250. isinstance(output, list)
  251. and all(
  252. isinstance(x, dict) and x.get("type") in TOOL_MESSAGE_BLOCK_TYPES
  253. for x in output
  254. )
  255. ):
  256. return output
  257. # Technically a list of strings is also valid message content, but it's
  258. # not currently well tested that all chat models support this.
  259. # And for backwards compatibility we want to make sure we don't break
  260. # any existing ToolNode usage.
  261. try:
  262. return json.dumps(output, ensure_ascii=False)
  263. except Exception: # noqa: BLE001
  264. return str(output)
  265. class ToolInvocationError(ToolException):
  266. """An error occurred while invoking a tool due to invalid arguments.
  267. This exception is only raised when invoking a tool using the `ToolNode`!
  268. """
  269. def __init__(
  270. self,
  271. tool_name: str,
  272. source: ValidationError,
  273. tool_kwargs: dict[str, Any],
  274. filtered_errors: list[ErrorDetails] | None = None,
  275. ) -> None:
  276. """Initialize the ToolInvocationError.
  277. Args:
  278. tool_name: The name of the tool that failed.
  279. source: The exception that occurred.
  280. tool_kwargs: The keyword arguments that were passed to the tool.
  281. filtered_errors: Optional list of filtered validation errors excluding
  282. injected arguments.
  283. """
  284. # Format error display based on filtered errors if provided
  285. if filtered_errors is not None:
  286. # Manually format the filtered errors without URLs or fancy formatting
  287. error_str_parts = []
  288. for error in filtered_errors:
  289. loc_str = ".".join(str(loc) for loc in error.get("loc", ()))
  290. msg = error.get("msg", "Unknown error")
  291. error_str_parts.append(f"{loc_str}: {msg}")
  292. error_display_str = "\n".join(error_str_parts)
  293. else:
  294. error_display_str = str(source)
  295. self.message = TOOL_INVOCATION_ERROR_TEMPLATE.format(
  296. tool_name=tool_name, tool_kwargs=tool_kwargs, error=error_display_str
  297. )
  298. self.tool_name = tool_name
  299. self.tool_kwargs = tool_kwargs
  300. self.source = source
  301. self.filtered_errors = filtered_errors
  302. super().__init__(self.message)
  303. def _default_handle_tool_errors(e: Exception) -> str:
  304. """Default error handler for tool errors.
  305. If the tool is a tool invocation error, return its message.
  306. Otherwise, raise the error.
  307. """
  308. if isinstance(e, ToolInvocationError):
  309. return e.message
  310. raise e
  311. def _handle_tool_error(
  312. e: Exception,
  313. *,
  314. flag: bool
  315. | str
  316. | Callable[..., str]
  317. | type[Exception]
  318. | tuple[type[Exception], ...],
  319. ) -> str:
  320. """Generate error message content based on exception handling configuration.
  321. This function centralizes error message generation logic, supporting different
  322. error handling strategies configured via the `ToolNode`'s `handle_tool_errors`
  323. parameter.
  324. Args:
  325. e: The exception that occurred during tool execution.
  326. flag: Configuration for how to handle the error. Can be:
  327. - bool: If `True`, use default error template
  328. - str: Use this string as the error message
  329. - Callable: Call this function with the exception to get error message
  330. - tuple: Not used in this context (handled by caller)
  331. Returns:
  332. A string containing the error message to include in the `ToolMessage`.
  333. Raises:
  334. ValueError: If flag is not one of the supported types.
  335. !!! note
  336. The tuple case is handled by the caller through exception type checking,
  337. not by this function directly.
  338. """
  339. if isinstance(flag, (bool, tuple)) or (
  340. isinstance(flag, type) and issubclass(flag, Exception)
  341. ):
  342. content = TOOL_CALL_ERROR_TEMPLATE.format(error=repr(e))
  343. elif isinstance(flag, str):
  344. content = flag
  345. elif callable(flag):
  346. content = flag(e) # type: ignore [assignment, call-arg]
  347. else:
  348. msg = (
  349. f"Got unexpected type of `handle_tool_error`. Expected bool, str "
  350. f"or callable. Received: {flag}"
  351. )
  352. raise ValueError(msg)
  353. return content
  354. def _infer_handled_types(handler: Callable[..., str]) -> tuple[type[Exception], ...]:
  355. """Infer exception types handled by a custom error handler function.
  356. This function analyzes the type annotations of a custom error handler to determine
  357. which exception types it's designed to handle. This enables type-safe error handling
  358. where only specific exceptions are caught and processed by the handler.
  359. Args:
  360. handler: A callable that takes an exception and returns an error message string.
  361. The first parameter (after self/cls if present) should be type-annotated
  362. with the exception type(s) to handle.
  363. Returns:
  364. A tuple of exception types that the handler can process. Returns (Exception,)
  365. if no specific type information is available for backward compatibility.
  366. Raises:
  367. ValueError: If the handler's annotation contains non-Exception types or
  368. if Union types contain non-Exception types.
  369. !!! note
  370. This function supports both single exception types and Union types for
  371. handlers that need to handle multiple exception types differently.
  372. """
  373. sig = inspect.signature(handler)
  374. params = list(sig.parameters.values())
  375. if params:
  376. # If it's a method, the first argument is typically 'self' or 'cls'
  377. if params[0].name in ["self", "cls"] and len(params) == 2:
  378. first_param = params[1]
  379. else:
  380. first_param = params[0]
  381. type_hints = get_type_hints(handler)
  382. if first_param.name in type_hints:
  383. origin = get_origin(first_param.annotation)
  384. if origin in [Union, UnionType]:
  385. args = get_args(first_param.annotation)
  386. if all(issubclass(arg, Exception) for arg in args):
  387. return tuple(args)
  388. msg = (
  389. "All types in the error handler error annotation must be "
  390. "Exception types. For example, "
  391. "`def custom_handler(e: Union[ValueError, TypeError])`. "
  392. f"Got '{first_param.annotation}' instead."
  393. )
  394. raise ValueError(msg)
  395. exception_type = type_hints[first_param.name]
  396. if Exception in exception_type.__mro__:
  397. return (exception_type,)
  398. msg = (
  399. f"Arbitrary types are not supported in the error handler "
  400. f"signature. Please annotate the error with either a "
  401. f"specific Exception type or a union of Exception types. "
  402. "For example, `def custom_handler(e: ValueError)` or "
  403. "`def custom_handler(e: Union[ValueError, TypeError])`. "
  404. f"Got '{exception_type}' instead."
  405. )
  406. raise ValueError(msg)
  407. # If no type information is available, return (Exception,)
  408. # for backwards compatibility.
  409. return (Exception,)
  410. def _filter_validation_errors(
  411. validation_error: ValidationError,
  412. injected_args: _InjectedArgs | None,
  413. ) -> list[ErrorDetails]:
  414. """Filter validation errors to only include LLM-controlled arguments.
  415. When a tool invocation fails validation, only errors for arguments that the LLM
  416. controls should be included in error messages. This ensures the LLM receives
  417. focused, actionable feedback about parameters it can actually fix. System-injected
  418. arguments (state, store, runtime) are filtered out since the LLM has no control
  419. over them.
  420. This function also removes injected argument values from the `input` field in error
  421. details, ensuring that only LLM-provided arguments appear in error messages.
  422. Args:
  423. validation_error: The Pydantic ValidationError raised during tool invocation.
  424. injected_args: The _InjectedArgs structure containing all injected arguments,
  425. or None if there are no injected arguments.
  426. Returns:
  427. List of ErrorDetails containing only errors for LLM-controlled arguments,
  428. with system-injected argument values removed from the input field.
  429. """
  430. # Collect all injected argument names
  431. injected_arg_names: set[str] = set()
  432. if injected_args:
  433. if injected_args.state:
  434. injected_arg_names.update(injected_args.state.keys())
  435. if injected_args.store:
  436. injected_arg_names.add(injected_args.store)
  437. if injected_args.runtime:
  438. injected_arg_names.add(injected_args.runtime)
  439. filtered_errors: list[ErrorDetails] = []
  440. for error in validation_error.errors():
  441. # Check if error location contains any injected argument
  442. # error['loc'] is a tuple like ('field_name',) or ('field_name', 'nested_field')
  443. if error["loc"] and error["loc"][0] not in injected_arg_names:
  444. # Create a copy of the error dict to avoid mutating the original
  445. error_copy: dict[str, Any] = {**error}
  446. # Remove injected arguments from input_value if it's a dict
  447. if isinstance(error_copy.get("input"), dict):
  448. input_dict = error_copy["input"]
  449. input_copy = {
  450. k: v for k, v in input_dict.items() if k not in injected_arg_names
  451. }
  452. error_copy["input"] = input_copy
  453. # Cast is safe because ErrorDetails is a TypedDict compatible with this structure
  454. filtered_errors.append(error_copy) # type: ignore[arg-type]
  455. return filtered_errors
@dataclass
class _InjectedArgs:
    """Internal structure for tracking injected arguments for a tool.

    The `ToolNode` builds one of these per tool during initialization, by
    analyzing the tool's signature and args schema, and stores it keyed by
    tool name. It is meant to be reused during execution, so injection does
    not require repeated reflection. `_filter_validation_errors` also reads it
    to decide which validation errors belong to injected parameters.

    The structure maps tool parameter names to their injection sources, so the
    `ToolNode` knows exactly which arguments need to be injected and where to
    get their values from.

    Attributes:
        state: Mapping from tool parameter names to state field names for
            injection. Keys are tool parameter names; values are either:

            - str: Name of the state field to extract and inject
            - None: Inject the entire state object

            Empty dict if no state injection is needed.
        store: Name of the tool parameter where the store should be injected,
            or None if no store injection is needed.
        runtime: Name of the tool parameter where the runtime should be
            injected, or None if no runtime injection is needed.

    Example:
        For a tool with signature:

        ```python
        def my_tool(
            x: int,
            messages: Annotated[list, InjectedState("messages")],
            full_state: Annotated[dict, InjectedState()],
            store: Annotated[BaseStore, InjectedStore()],
            runtime: ToolRuntime,
        ) -> str: ...
        ```

        The resulting `_InjectedArgs` would be:

        ```python
        _InjectedArgs(
            state={
                "messages": "messages",  # Extract state["messages"]
                "full_state": None,  # Inject entire state
            },
            store="store",  # Inject into "store" parameter
            runtime="runtime",  # Inject into "runtime" parameter
        )
        ```
    """

    # Tool parameter name -> state field name (None = whole state).
    state: dict[str, str | None]
    # Parameter that receives the store, if any.
    store: str | None
    # Parameter that receives the tool runtime, if any.
    runtime: str | None
  502. class ToolNode(RunnableCallable):
  503. """A node for executing tools in LangGraph workflows.
  504. Handles tool execution patterns including function calls, state injection,
  505. persistent storage, and control flow. Manages parallel execution,
  506. error handling.
  507. Input Formats:
  508. 1. Graph state with `messages` key that has a list of messages:
  509. - Common representation for agentic workflows
  510. - Supports custom messages key via `messages_key` parameter
  511. 2. **Message List**: `[AIMessage(..., tool_calls=[...])]`
  512. - List of messages with tool calls in the last AIMessage
  513. 3. **Direct Tool Calls**: `[{"name": "tool", "args": {...}, "id": "1", "type": "tool_call"}]`
  514. - Bypasses message parsing for direct tool execution
  515. - For programmatic tool invocation and testing
  516. Output Formats:
  517. Output format depends on input type and tool behavior:
  518. **For Regular tools**:
  519. - Dict input → `{"messages": [ToolMessage(...)]}`
  520. - List input → `[ToolMessage(...)]`
  521. **For Command tools**:
  522. - Returns `[Command(...)]` or mixed list with regular tool outputs
  523. - `Command` can update state, trigger navigation, or send messages
  524. Args:
  525. tools: A sequence of tools that can be invoked by this node.
  526. Supports:
  527. - **BaseTool instances**: Tools with schemas and metadata
  528. - **Plain functions**: Automatically converted to tools with inferred schemas
  529. name: The name identifier for this node in the graph. Used for debugging
  530. and visualization.
  531. tags: Optional metadata tags to associate with the node for filtering
  532. and organization.
  533. handle_tool_errors: Configuration for error handling during tool execution.
  534. Supports multiple strategies:
  535. - `True`: Catch all errors and return a `ToolMessage` with the default
  536. error template containing the exception details.
  537. - `str`: Catch all errors and return a `ToolMessage` with this custom
  538. error message string.
  539. - `type[Exception]`: Only catch exceptions with the specified type and
  540. return the default error message for it.
  541. - `tuple[type[Exception], ...]`: Only catch exceptions with the specified
  542. types and return default error messages for them.
  543. - `Callable[..., str]`: Catch exceptions matching the callable's signature
  544. and return the string result of calling it with the exception.
  545. - `False`: Disable error handling entirely, allowing exceptions to
  546. propagate.
  547. Defaults to a callable that:
  548. - Catches tool invocation errors (due to invalid arguments provided by the
  549. model) and returns a descriptive error message
  550. - Ignores tool execution errors (they will be re-raised)
  551. messages_key: The key in the state dictionary that contains the message list.
  552. This same key will be used for the output `ToolMessage` objects.
  553. Allows custom state schemas with different message field names.
  554. Examples:
  555. Basic usage:
  556. ```python
  557. from langchain.tools import ToolNode
  558. from langchain_core.tools import tool
  559. @tool
  560. def calculator(a: int, b: int) -> int:
  561. \"\"\"Add two numbers.\"\"\"
  562. return a + b
  563. tool_node = ToolNode([calculator])
  564. ```
  565. State injection:
  566. ```python
  567. from typing_extensions import Annotated
  568. from langchain.tools import InjectedState
  569. @tool
  570. def context_tool(query: str, state: Annotated[dict, InjectedState]) -> str:
  571. \"\"\"Some tool that uses state.\"\"\"
  572. return f"Query: {query}, Messages: {len(state['messages'])}"
  573. tool_node = ToolNode([context_tool])
  574. ```
  575. Error handling:
  576. ```python
  577. def handle_errors(e: ValueError) -> str:
  578. return "Invalid input provided"
  579. tool_node = ToolNode([my_tool], handle_tool_errors=handle_errors)
  580. ```
  581. """ # noqa: E501
  582. name: str = "tools"
  583. def __init__(
  584. self,
  585. tools: Sequence[BaseTool | Callable],
  586. *,
  587. name: str = "tools",
  588. tags: list[str] | None = None,
  589. handle_tool_errors: bool
  590. | str
  591. | Callable[..., str]
  592. | type[Exception]
  593. | tuple[type[Exception], ...] = _default_handle_tool_errors,
  594. messages_key: str = "messages",
  595. wrap_tool_call: ToolCallWrapper | None = None,
  596. awrap_tool_call: AsyncToolCallWrapper | None = None,
  597. ) -> None:
  598. """Initialize `ToolNode` with tools and configuration.
  599. Args:
  600. tools: Sequence of tools to make available for execution.
  601. name: Node name for graph identification.
  602. tags: Optional metadata tags.
  603. handle_tool_errors: Error handling configuration.
  604. messages_key: State key containing messages.
  605. wrap_tool_call: Sync wrapper function to intercept tool execution. Receives
  606. ToolCallRequest and execute callable, returns ToolMessage or Command.
  607. Enables retries, caching, request modification, and control flow.
  608. awrap_tool_call: Async wrapper function to intercept tool execution.
  609. If not provided, falls back to wrap_tool_call for async execution.
  610. """
  611. super().__init__(self._func, self._afunc, name=name, tags=tags, trace=False)
  612. self._tools_by_name: dict[str, BaseTool] = {}
  613. self._injected_args: dict[str, _InjectedArgs] = {}
  614. self._handle_tool_errors = handle_tool_errors
  615. self._messages_key = messages_key
  616. self._wrap_tool_call = wrap_tool_call
  617. self._awrap_tool_call = awrap_tool_call
  618. for tool in tools:
  619. if not isinstance(tool, BaseTool):
  620. tool_ = create_tool(cast("type[BaseTool]", tool))
  621. else:
  622. tool_ = tool
  623. self._tools_by_name[tool_.name] = tool_
  624. # Build injected args mapping once during initialization in a single pass
  625. self._injected_args[tool_.name] = _get_all_injected_args(tool_)
  626. @property
  627. def tools_by_name(self) -> dict[str, BaseTool]:
  628. """Mapping from tool name to BaseTool instance."""
  629. return self._tools_by_name
  630. def _func(
  631. self,
  632. input: list[AnyMessage] | dict[str, Any] | BaseModel,
  633. config: RunnableConfig,
  634. runtime: Runtime,
  635. ) -> Any:
  636. tool_calls, input_type = self._parse_input(input)
  637. config_list = get_config_list(config, len(tool_calls))
  638. # Construct ToolRuntime instances at the top level for each tool call
  639. tool_runtimes = []
  640. for call, cfg in zip(tool_calls, config_list, strict=False):
  641. state = self._extract_state(input)
  642. tool_runtime = ToolRuntime(
  643. state=state,
  644. tool_call_id=call["id"],
  645. config=cfg,
  646. context=runtime.context,
  647. store=runtime.store,
  648. stream_writer=runtime.stream_writer,
  649. )
  650. tool_runtimes.append(tool_runtime)
  651. # Pass original tool calls without injection
  652. input_types = [input_type] * len(tool_calls)
  653. with get_executor_for_config(config) as executor:
  654. outputs = list(
  655. executor.map(self._run_one, tool_calls, input_types, tool_runtimes)
  656. )
  657. return self._combine_tool_outputs(outputs, input_type)
  658. async def _afunc(
  659. self,
  660. input: list[AnyMessage] | dict[str, Any] | BaseModel,
  661. config: RunnableConfig,
  662. runtime: Runtime,
  663. ) -> Any:
  664. tool_calls, input_type = self._parse_input(input)
  665. config_list = get_config_list(config, len(tool_calls))
  666. # Construct ToolRuntime instances at the top level for each tool call
  667. tool_runtimes = []
  668. for call, cfg in zip(tool_calls, config_list, strict=False):
  669. state = self._extract_state(input)
  670. tool_runtime = ToolRuntime(
  671. state=state,
  672. tool_call_id=call["id"],
  673. config=cfg,
  674. context=runtime.context,
  675. store=runtime.store,
  676. stream_writer=runtime.stream_writer,
  677. )
  678. tool_runtimes.append(tool_runtime)
  679. # Pass original tool calls without injection
  680. coros = []
  681. for call, tool_runtime in zip(tool_calls, tool_runtimes, strict=False):
  682. coros.append(self._arun_one(call, input_type, tool_runtime)) # type: ignore[arg-type]
  683. outputs = await asyncio.gather(*coros)
  684. return self._combine_tool_outputs(outputs, input_type)
  685. def _combine_tool_outputs(
  686. self,
  687. outputs: list[ToolMessage | Command],
  688. input_type: Literal["list", "dict", "tool_calls"],
  689. ) -> list[Command | list[ToolMessage] | dict[str, list[ToolMessage]]]:
  690. # preserve existing behavior for non-command tool outputs for backwards
  691. # compatibility
  692. if not any(isinstance(output, Command) for output in outputs):
  693. # TypedDict, pydantic, dataclass, etc. should all be able to load from dict
  694. return outputs if input_type == "list" else {self._messages_key: outputs}
  695. # LangGraph will automatically handle list of Command and non-command node
  696. # updates
  697. combined_outputs: list[
  698. Command | list[ToolMessage] | dict[str, list[ToolMessage]]
  699. ] = []
  700. # combine all parent commands with goto into a single parent command
  701. parent_command: Command | None = None
  702. for output in outputs:
  703. if isinstance(output, Command):
  704. if (
  705. output.graph is Command.PARENT
  706. and isinstance(output.goto, list)
  707. and all(isinstance(send, Send) for send in output.goto)
  708. ):
  709. if parent_command:
  710. parent_command = replace(
  711. parent_command,
  712. goto=cast("list[Send]", parent_command.goto) + output.goto,
  713. )
  714. else:
  715. parent_command = Command(graph=Command.PARENT, goto=output.goto)
  716. else:
  717. combined_outputs.append(output)
  718. else:
  719. combined_outputs.append(
  720. [output] if input_type == "list" else {self._messages_key: [output]}
  721. )
  722. if parent_command:
  723. combined_outputs.append(parent_command)
  724. return combined_outputs
  725. def _execute_tool_sync(
  726. self,
  727. request: ToolCallRequest,
  728. input_type: Literal["list", "dict", "tool_calls"],
  729. config: RunnableConfig,
  730. ) -> ToolMessage | Command:
  731. """Execute tool call with configured error handling.
  732. Args:
  733. request: Tool execution request.
  734. input_type: Input format.
  735. config: Runnable configuration.
  736. Returns:
  737. ToolMessage or Command.
  738. Raises:
  739. Exception: If tool fails and handle_tool_errors is False.
  740. """
  741. call = request.tool_call
  742. tool = request.tool
  743. # Validate tool exists when we actually need to execute it
  744. if tool is None:
  745. if invalid_tool_message := self._validate_tool_call(call):
  746. return invalid_tool_message
  747. # This should never happen if validation works correctly
  748. msg = f"Tool {call['name']} is not registered with ToolNode"
  749. raise TypeError(msg)
  750. # Inject state, store, and runtime right before invocation
  751. injected_call = self._inject_tool_args(call, request.runtime)
  752. call_args = {**injected_call, "type": "tool_call"}
  753. try:
  754. try:
  755. response = tool.invoke(call_args, config)
  756. except ValidationError as exc:
  757. # Filter out errors for injected arguments
  758. injected = self._injected_args.get(call["name"])
  759. filtered_errors = _filter_validation_errors(exc, injected)
  760. # Use original call["args"] without injected values for error reporting
  761. raise ToolInvocationError(
  762. call["name"], exc, call["args"], filtered_errors
  763. ) from exc
  764. # GraphInterrupt is a special exception that will always be raised.
  765. # It can be triggered in the following scenarios,
  766. # Where GraphInterrupt(GraphBubbleUp) is raised from an `interrupt` invocation
  767. # most commonly:
  768. # (1) a GraphInterrupt is raised inside a tool
  769. # (2) a GraphInterrupt is raised inside a graph node for a graph called as a tool
  770. # (3) a GraphInterrupt is raised when a subgraph is interrupted inside a graph
  771. # called as a tool
  772. # (2 and 3 can happen in a "supervisor w/ tools" multi-agent architecture)
  773. except GraphBubbleUp:
  774. raise
  775. except Exception as e:
  776. # Determine which exception types are handled
  777. handled_types: tuple[type[Exception], ...]
  778. if isinstance(self._handle_tool_errors, type) and issubclass(
  779. self._handle_tool_errors, Exception
  780. ):
  781. handled_types = (self._handle_tool_errors,)
  782. elif isinstance(self._handle_tool_errors, tuple):
  783. handled_types = self._handle_tool_errors
  784. elif callable(self._handle_tool_errors) and not isinstance(
  785. self._handle_tool_errors, type
  786. ):
  787. handled_types = _infer_handled_types(self._handle_tool_errors)
  788. else:
  789. # default behavior is catching all exceptions
  790. handled_types = (Exception,)
  791. # Check if this error should be handled
  792. if not self._handle_tool_errors or not isinstance(e, handled_types):
  793. raise
  794. # Error is handled - create error ToolMessage
  795. content = _handle_tool_error(e, flag=self._handle_tool_errors)
  796. return ToolMessage(
  797. content=content,
  798. name=call["name"],
  799. tool_call_id=call["id"],
  800. status="error",
  801. )
  802. # Process successful response
  803. if isinstance(response, Command):
  804. # Validate Command before returning to handler
  805. return self._validate_tool_command(response, request.tool_call, input_type)
  806. if isinstance(response, ToolMessage):
  807. response.content = cast("str | list", msg_content_output(response.content))
  808. return response
  809. msg = f"Tool {call['name']} returned unexpected type: {type(response)}"
  810. raise TypeError(msg)
  811. def _run_one(
  812. self,
  813. call: ToolCall,
  814. input_type: Literal["list", "dict", "tool_calls"],
  815. tool_runtime: ToolRuntime,
  816. ) -> ToolMessage | Command:
  817. """Execute single tool call with wrap_tool_call wrapper if configured.
  818. Args:
  819. call: Tool call dict.
  820. input_type: Input format.
  821. tool_runtime: Tool runtime.
  822. Returns:
  823. ToolMessage or Command.
  824. """
  825. # Validation is deferred to _execute_tool_sync to allow interceptors
  826. # to short-circuit requests for unregistered tools
  827. tool = self.tools_by_name.get(call["name"])
  828. # Create the tool request with state and runtime
  829. tool_request = ToolCallRequest(
  830. tool_call=call,
  831. tool=tool,
  832. state=tool_runtime.state,
  833. runtime=tool_runtime,
  834. )
  835. config = tool_runtime.config
  836. if self._wrap_tool_call is None:
  837. # No wrapper - execute directly
  838. return self._execute_tool_sync(tool_request, input_type, config)
  839. # Define execute callable that can be called multiple times
  840. def execute(req: ToolCallRequest) -> ToolMessage | Command:
  841. """Execute tool with given request. Can be called multiple times."""
  842. return self._execute_tool_sync(req, input_type, config)
  843. # Call wrapper with request and execute callable
  844. try:
  845. return self._wrap_tool_call(tool_request, execute)
  846. except Exception as e:
  847. # Wrapper threw an exception
  848. if not self._handle_tool_errors:
  849. raise
  850. # Convert to error message
  851. content = _handle_tool_error(e, flag=self._handle_tool_errors)
  852. return ToolMessage(
  853. content=content,
  854. name=tool_request.tool_call["name"],
  855. tool_call_id=tool_request.tool_call["id"],
  856. status="error",
  857. )
  858. async def _execute_tool_async(
  859. self,
  860. request: ToolCallRequest,
  861. input_type: Literal["list", "dict", "tool_calls"],
  862. config: RunnableConfig,
  863. ) -> ToolMessage | Command:
  864. """Execute tool call asynchronously with configured error handling.
  865. Args:
  866. request: Tool execution request.
  867. input_type: Input format.
  868. config: Runnable configuration.
  869. Returns:
  870. ToolMessage or Command.
  871. Raises:
  872. Exception: If tool fails and handle_tool_errors is False.
  873. """
  874. call = request.tool_call
  875. tool = request.tool
  876. # Validate tool exists when we actually need to execute it
  877. if tool is None:
  878. if invalid_tool_message := self._validate_tool_call(call):
  879. return invalid_tool_message
  880. # This should never happen if validation works correctly
  881. msg = f"Tool {call['name']} is not registered with ToolNode"
  882. raise TypeError(msg)
  883. # Inject state, store, and runtime right before invocation
  884. injected_call = self._inject_tool_args(call, request.runtime)
  885. call_args = {**injected_call, "type": "tool_call"}
  886. try:
  887. try:
  888. response = await tool.ainvoke(call_args, config)
  889. except ValidationError as exc:
  890. # Filter out errors for injected arguments
  891. injected = self._injected_args.get(call["name"])
  892. filtered_errors = _filter_validation_errors(exc, injected)
  893. # Use original call["args"] without injected values for error reporting
  894. raise ToolInvocationError(
  895. call["name"], exc, call["args"], filtered_errors
  896. ) from exc
  897. # GraphInterrupt is a special exception that will always be raised.
  898. # It can be triggered in the following scenarios,
  899. # Where GraphInterrupt(GraphBubbleUp) is raised from an `interrupt` invocation
  900. # most commonly:
  901. # (1) a GraphInterrupt is raised inside a tool
  902. # (2) a GraphInterrupt is raised inside a graph node for a graph called as a tool
  903. # (3) a GraphInterrupt is raised when a subgraph is interrupted inside a graph
  904. # called as a tool
  905. # (2 and 3 can happen in a "supervisor w/ tools" multi-agent architecture)
  906. except GraphBubbleUp:
  907. raise
  908. except Exception as e:
  909. # Determine which exception types are handled
  910. handled_types: tuple[type[Exception], ...]
  911. if isinstance(self._handle_tool_errors, type) and issubclass(
  912. self._handle_tool_errors, Exception
  913. ):
  914. handled_types = (self._handle_tool_errors,)
  915. elif isinstance(self._handle_tool_errors, tuple):
  916. handled_types = self._handle_tool_errors
  917. elif callable(self._handle_tool_errors) and not isinstance(
  918. self._handle_tool_errors, type
  919. ):
  920. handled_types = _infer_handled_types(self._handle_tool_errors)
  921. else:
  922. # default behavior is catching all exceptions
  923. handled_types = (Exception,)
  924. # Check if this error should be handled
  925. if not self._handle_tool_errors or not isinstance(e, handled_types):
  926. raise
  927. # Error is handled - create error ToolMessage
  928. content = _handle_tool_error(e, flag=self._handle_tool_errors)
  929. return ToolMessage(
  930. content=content,
  931. name=call["name"],
  932. tool_call_id=call["id"],
  933. status="error",
  934. )
  935. # Process successful response
  936. if isinstance(response, Command):
  937. # Validate Command before returning to handler
  938. return self._validate_tool_command(response, request.tool_call, input_type)
  939. if isinstance(response, ToolMessage):
  940. response.content = cast("str | list", msg_content_output(response.content))
  941. return response
  942. msg = f"Tool {call['name']} returned unexpected type: {type(response)}"
  943. raise TypeError(msg)
  944. async def _arun_one(
  945. self,
  946. call: ToolCall,
  947. input_type: Literal["list", "dict", "tool_calls"],
  948. tool_runtime: ToolRuntime,
  949. ) -> ToolMessage | Command:
  950. """Execute single tool call asynchronously with awrap_tool_call wrapper if configured.
  951. Args:
  952. call: Tool call dict.
  953. input_type: Input format.
  954. tool_runtime: Tool runtime.
  955. Returns:
  956. ToolMessage or Command.
  957. """
  958. # Validation is deferred to _execute_tool_async to allow interceptors
  959. # to short-circuit requests for unregistered tools
  960. tool = self.tools_by_name.get(call["name"])
  961. # Create the tool request with state and runtime
  962. tool_request = ToolCallRequest(
  963. tool_call=call,
  964. tool=tool,
  965. state=tool_runtime.state,
  966. runtime=tool_runtime,
  967. )
  968. config = tool_runtime.config
  969. if self._awrap_tool_call is None and self._wrap_tool_call is None:
  970. # No wrapper - execute directly
  971. return await self._execute_tool_async(tool_request, input_type, config)
  972. # Define async execute callable that can be called multiple times
  973. async def execute(req: ToolCallRequest) -> ToolMessage | Command:
  974. """Execute tool with given request. Can be called multiple times."""
  975. return await self._execute_tool_async(req, input_type, config)
  976. def _sync_execute(req: ToolCallRequest) -> ToolMessage | Command:
  977. """Sync execute fallback for sync wrapper."""
  978. return self._execute_tool_sync(req, input_type, config)
  979. # Call wrapper with request and execute callable
  980. try:
  981. if self._awrap_tool_call is not None:
  982. return await self._awrap_tool_call(tool_request, execute)
  983. # None check was performed above already
  984. self._wrap_tool_call = cast("ToolCallWrapper", self._wrap_tool_call)
  985. return self._wrap_tool_call(tool_request, _sync_execute)
  986. except Exception as e:
  987. # Wrapper threw an exception
  988. if not self._handle_tool_errors:
  989. raise
  990. # Convert to error message
  991. content = _handle_tool_error(e, flag=self._handle_tool_errors)
  992. return ToolMessage(
  993. content=content,
  994. name=tool_request.tool_call["name"],
  995. tool_call_id=tool_request.tool_call["id"],
  996. status="error",
  997. )
  998. def _parse_input(
  999. self,
  1000. input: list[AnyMessage] | dict[str, Any] | BaseModel,
  1001. ) -> tuple[list[ToolCall], Literal["list", "dict", "tool_calls"]]:
  1002. input_type: Literal["list", "dict", "tool_calls"]
  1003. if isinstance(input, list):
  1004. if isinstance(input[-1], dict) and input[-1].get("type") == "tool_call":
  1005. input_type = "tool_calls"
  1006. tool_calls = cast("list[ToolCall]", input)
  1007. return tool_calls, input_type
  1008. input_type = "list"
  1009. messages = input
  1010. elif (
  1011. isinstance(input, dict) and input.get("__type") == "tool_call_with_context"
  1012. ):
  1013. # Handle ToolCallWithContext from Send API
  1014. # mypy will not be able to type narrow correctly since the signature
  1015. # for input contains dict[str, Any]. We'd need to narrow dict[str, Any]
  1016. # before we can apply correct typing.
  1017. input_with_ctx = cast("ToolCallWithContext", input)
  1018. input_type = "tool_calls"
  1019. return [input_with_ctx["tool_call"]], input_type
  1020. elif isinstance(input, dict) and (
  1021. messages := input.get(self._messages_key, [])
  1022. ):
  1023. input_type = "dict"
  1024. elif messages := getattr(input, self._messages_key, []):
  1025. # Assume dataclass-like state that can coerce from dict
  1026. input_type = "dict"
  1027. else:
  1028. msg = "No message found in input"
  1029. raise ValueError(msg)
  1030. try:
  1031. latest_ai_message = next(
  1032. m for m in reversed(messages) if isinstance(m, AIMessage)
  1033. )
  1034. except StopIteration:
  1035. msg = "No AIMessage found in input"
  1036. raise ValueError(msg)
  1037. tool_calls = list(latest_ai_message.tool_calls)
  1038. return tool_calls, input_type
  1039. def _validate_tool_call(self, call: ToolCall) -> ToolMessage | None:
  1040. requested_tool = call["name"]
  1041. if requested_tool not in self.tools_by_name:
  1042. all_tool_names = list(self.tools_by_name.keys())
  1043. content = INVALID_TOOL_NAME_ERROR_TEMPLATE.format(
  1044. requested_tool=requested_tool,
  1045. available_tools=", ".join(all_tool_names),
  1046. )
  1047. return ToolMessage(
  1048. content, name=requested_tool, tool_call_id=call["id"], status="error"
  1049. )
  1050. return None
  1051. def _extract_state(
  1052. self, input: list[AnyMessage] | dict[str, Any] | BaseModel
  1053. ) -> list[AnyMessage] | dict[str, Any] | BaseModel:
  1054. """Extract state from input, handling ToolCallWithContext if present.
  1055. Args:
  1056. input: The input which may be raw state or ToolCallWithContext.
  1057. Returns:
  1058. The actual state to pass to wrap_tool_call wrappers.
  1059. """
  1060. if isinstance(input, dict) and input.get("__type") == "tool_call_with_context":
  1061. return input["state"]
  1062. return input
  1063. def _inject_tool_args(
  1064. self,
  1065. tool_call: ToolCall,
  1066. tool_runtime: ToolRuntime,
  1067. ) -> ToolCall:
  1068. """Inject graph state, store, and runtime into tool call arguments.
  1069. This is an internal method that enables tools to access graph context that
  1070. should not be controlled by the model. Tools can declare dependencies on graph
  1071. state, persistent storage, or runtime context using InjectedState, InjectedStore,
  1072. and ToolRuntime annotations. This method automatically identifies these
  1073. dependencies and injects the appropriate values.
  1074. The injection process preserves the original tool call structure while adding
  1075. the necessary context arguments. This allows tools to be both model-callable
  1076. and context-aware without exposing internal state management to the model.
  1077. Args:
  1078. tool_call: The tool call dictionary to augment with injected arguments.
  1079. Must contain 'name', 'args', 'id', and 'type' fields.
  1080. tool_runtime: The ToolRuntime instance containing all runtime context
  1081. (state, config, store, context, stream_writer) to inject into tools.
  1082. Returns:
  1083. A new ToolCall dictionary with the same structure as the input but with
  1084. additional arguments injected based on the tool's annotation requirements.
  1085. Raises:
  1086. ValueError: If a tool requires store injection but no store is provided,
  1087. or if state injection requirements cannot be satisfied.
  1088. !!! note
  1089. This method is called automatically during tool execution. It should not
  1090. be called from outside the `ToolNode`.
  1091. """
  1092. if tool_call["name"] not in self.tools_by_name:
  1093. return tool_call
  1094. injected = self._injected_args.get(tool_call["name"])
  1095. if not injected:
  1096. return tool_call
  1097. tool_call_copy: ToolCall = copy(tool_call)
  1098. injected_args = {}
  1099. # Inject state
  1100. if injected.state:
  1101. state = tool_runtime.state
  1102. # Handle list state by converting to dict
  1103. if isinstance(state, list):
  1104. required_fields = list(injected.state.values())
  1105. if (
  1106. len(required_fields) == 1
  1107. and required_fields[0] == self._messages_key
  1108. ) or required_fields[0] is None:
  1109. state = {self._messages_key: state}
  1110. else:
  1111. err_msg = (
  1112. f"Invalid input to ToolNode. Tool {tool_call['name']} requires "
  1113. f"graph state dict as input."
  1114. )
  1115. if any(state_field for state_field in injected.state.values()):
  1116. required_fields_str = ", ".join(f for f in required_fields if f)
  1117. err_msg += (
  1118. f" State should contain fields {required_fields_str}."
  1119. )
  1120. raise ValueError(err_msg)
  1121. # Extract state values
  1122. if isinstance(state, dict):
  1123. for tool_arg, state_field in injected.state.items():
  1124. injected_args[tool_arg] = (
  1125. state[state_field] if state_field else state
  1126. )
  1127. else:
  1128. for tool_arg, state_field in injected.state.items():
  1129. injected_args[tool_arg] = (
  1130. getattr(state, state_field) if state_field else state
  1131. )
  1132. # Inject store
  1133. if injected.store:
  1134. if tool_runtime.store is None:
  1135. msg = (
  1136. "Cannot inject store into tools with InjectedStore annotations - "
  1137. "please compile your graph with a store."
  1138. )
  1139. raise ValueError(msg)
  1140. injected_args[injected.store] = tool_runtime.store
  1141. # Inject runtime
  1142. if injected.runtime:
  1143. injected_args[injected.runtime] = tool_runtime
  1144. tool_call_copy["args"] = {**tool_call_copy["args"], **injected_args}
  1145. return tool_call_copy
  1146. def _validate_tool_command(
  1147. self,
  1148. command: Command,
  1149. call: ToolCall,
  1150. input_type: Literal["list", "dict", "tool_calls"],
  1151. ) -> Command:
  1152. if isinstance(command.update, dict):
  1153. # input type is dict when ToolNode is invoked with a dict input
  1154. # (e.g. {"messages": [AIMessage(..., tool_calls=[...])]})
  1155. if input_type not in ("dict", "tool_calls"):
  1156. msg = (
  1157. "Tools can provide a dict in Command.update only when using dict "
  1158. f"with '{self._messages_key}' key as ToolNode input, "
  1159. f"got: {command.update} for tool '{call['name']}'"
  1160. )
  1161. raise ValueError(msg)
  1162. updated_command = deepcopy(command)
  1163. state_update = cast("dict[str, Any]", updated_command.update) or {}
  1164. messages_update = state_update.get(self._messages_key, [])
  1165. elif isinstance(command.update, list):
  1166. # Input type is list when ToolNode is invoked with a list input
  1167. # (e.g. [AIMessage(..., tool_calls=[...])])
  1168. if input_type != "list":
  1169. msg = (
  1170. "Tools can provide a list of messages in Command.update "
  1171. "only when using list of messages as ToolNode input, "
  1172. f"got: {command.update} for tool '{call['name']}'"
  1173. )
  1174. raise ValueError(msg)
  1175. updated_command = deepcopy(command)
  1176. messages_update = updated_command.update
  1177. else:
  1178. return command
  1179. # convert to message objects if updates are in a dict format
  1180. messages_update = convert_to_messages(messages_update)
  1181. # no validation needed if all messages are being removed
  1182. if messages_update == [RemoveMessage(id=REMOVE_ALL_MESSAGES)]:
  1183. return updated_command
  1184. has_matching_tool_message = False
  1185. for message in messages_update:
  1186. if not isinstance(message, ToolMessage):
  1187. continue
  1188. if message.tool_call_id == call["id"]:
  1189. message.name = call["name"]
  1190. has_matching_tool_message = True
  1191. # validate that we always have a ToolMessage matching the tool call in
  1192. # Command.update if command is sent to the CURRENT graph
  1193. if updated_command.graph is None and not has_matching_tool_message:
  1194. example_update = (
  1195. '`Command(update={"messages": '
  1196. '[ToolMessage("Success", tool_call_id=tool_call_id), ...]}, ...)`'
  1197. if input_type == "dict"
  1198. else "`Command(update="
  1199. '[ToolMessage("Success", tool_call_id=tool_call_id), ...], ...)`'
  1200. )
  1201. msg = (
  1202. "Expected to have a matching ToolMessage in Command.update "
  1203. f"for tool '{call['name']}', got: {messages_update}. "
  1204. "Every tool call (LLM requesting to call a tool) "
  1205. "in the message history MUST have a corresponding ToolMessage. "
  1206. f"You can fix it by modifying the tool to return {example_update}."
  1207. )
  1208. raise ValueError(msg)
  1209. return updated_command
  1210. def tools_condition(
  1211. state: list[AnyMessage] | dict[str, Any] | BaseModel,
  1212. messages_key: str = "messages",
  1213. ) -> Literal["tools", "__end__"]:
  1214. """Conditional routing function for tool-calling workflows.
  1215. This utility function implements the standard conditional logic for ReAct-style
  1216. agents: if the last `AIMessage` contains tool calls, route to the tool execution
  1217. node; otherwise, end the workflow. This pattern is fundamental to most tool-calling
  1218. agent architectures.
  1219. The function handles multiple state formats commonly used in LangGraph applications,
  1220. making it flexible for different graph designs while maintaining consistent behavior.
  1221. Args:
  1222. state: The current graph state to examine for tool calls. Supported formats:
  1223. - Dictionary containing a messages key (for `StateGraph`)
  1224. - `BaseModel` instance with a messages attribute
  1225. messages_key: The key or attribute name containing the message list in the state.
  1226. This allows customization for graphs using different state schemas.
  1227. Returns:
  1228. Either `'tools'` if tool calls are present in the last `AIMessage`, or `'__end__'`
  1229. to terminate the workflow. These are the standard routing destinations for
  1230. tool-calling conditional edges.
  1231. Raises:
  1232. ValueError: If no messages can be found in the provided state format.
  1233. Example:
  1234. Basic usage in a ReAct agent:
  1235. ```python
  1236. from langgraph.graph import StateGraph
  1237. from langchain.tools import ToolNode
  1238. from langchain.tools.tool_node import tools_condition
  1239. from typing_extensions import TypedDict
  1240. class State(TypedDict):
  1241. messages: list
  1242. graph = StateGraph(State)
  1243. graph.add_node("llm", call_model)
  1244. graph.add_node("tools", ToolNode([my_tool]))
  1245. graph.add_conditional_edges(
  1246. "llm",
  1247. tools_condition, # Routes to "tools" or "__end__"
  1248. {"tools": "tools", "__end__": "__end__"},
  1249. )
  1250. ```
  1251. Custom messages key:
  1252. ```python
  1253. def custom_condition(state):
  1254. return tools_condition(state, messages_key="chat_history")
  1255. ```
  1256. !!! note
  1257. This function is designed to work seamlessly with `ToolNode` and standard
  1258. LangGraph patterns. It expects the last message to be an `AIMessage` when
  1259. tool calls are present, which is the standard output format for tool-calling
  1260. language models.
  1261. """
  1262. if isinstance(state, list):
  1263. ai_message = state[-1]
  1264. elif (isinstance(state, dict) and (messages := state.get(messages_key, []))) or (
  1265. messages := getattr(state, messages_key, [])
  1266. ):
  1267. ai_message = messages[-1]
  1268. else:
  1269. msg = f"No messages found in input state to tool_edge: {state}"
  1270. raise ValueError(msg)
  1271. if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0:
  1272. return "tools"
  1273. return "__end__"
@dataclass
class ToolRuntime(_DirectlyInjectedToolArg, Generic[ContextT, StateT]):
    """Runtime context automatically injected into tools.

    When a tool function declares a parameter whose type hint is
    `ToolRuntime`, the tool execution system automatically injects an
    instance containing:

    - `state`: The current graph state
    - `tool_call_id`: The ID of the current tool call
    - `config`: `RunnableConfig` for the current execution
    - `context`: Runtime context (from langgraph `Runtime`)
    - `store`: `BaseStore` instance for persistent storage (from langgraph
      `Runtime`), or `None` if no store is available
    - `stream_writer`: `StreamWriter` for streaming output (from langgraph `Runtime`)

    No `Annotated` wrapper is needed - just use `runtime: ToolRuntime`
    as a parameter.

    Example:
        ```python
        from langchain_core.tools import tool
        from langchain.tools import ToolRuntime


        @tool
        def my_tool(x: int, runtime: ToolRuntime) -> str:
            \"\"\"Tool that accesses runtime context.\"\"\"
            # Access state
            messages = runtime.state["messages"]

            # Access tool_call_id
            print(f"Tool call ID: {runtime.tool_call_id}")

            # Access config
            print(f"Run ID: {runtime.config.get('run_id')}")

            # Access runtime context (assumes a dict-like context schema)
            user_id = runtime.context.get("user_id")

            # Access store (stored values are dicts)
            runtime.store.put(("metrics",), "count", {"value": 1})

            # Stream output (the writer is called directly)
            runtime.stream_writer("Processing...")

            return f"Processed {x}"
        ```

    !!! note
        The `ToolRuntime` type hint is used to detect which parameter receives
        the injected value. `ToolNode` constructs the actual instance for each
        tool call at execution time.
    """

    state: StateT
    context: ContextT
    config: RunnableConfig
    stream_writer: StreamWriter
    tool_call_id: str | None
    store: BaseStore | None
  1319. class InjectedState(InjectedToolArg):
  1320. """Annotation for injecting graph state into tool arguments.
  1321. This annotation enables tools to access graph state without exposing state
  1322. management details to the language model. Tools annotated with `InjectedState`
  1323. receive state data automatically during execution while remaining invisible
  1324. to the model's tool-calling interface.
  1325. Args:
  1326. field: Optional key to extract from the state dictionary. If `None`, the entire
  1327. state is injected. If specified, only that field's value is injected.
  1328. This allows tools to request specific state components rather than
  1329. processing the full state structure.
  1330. Example:
  1331. ```python
  1332. from typing import List
  1333. from typing_extensions import Annotated, TypedDict
  1334. from langchain_core.messages import BaseMessage, AIMessage
  1335. from langchain.tools import InjectedState, ToolNode, tool
  1336. class AgentState(TypedDict):
  1337. messages: List[BaseMessage]
  1338. foo: str
  1339. @tool
  1340. def state_tool(x: int, state: Annotated[dict, InjectedState]) -> str:
  1341. '''Do something with state.'''
  1342. if len(state["messages"]) > 2:
  1343. return state["foo"] + str(x)
  1344. else:
  1345. return "not enough messages"
  1346. @tool
  1347. def foo_tool(x: int, foo: Annotated[str, InjectedState("foo")]) -> str:
  1348. '''Do something else with state.'''
  1349. return foo + str(x + 1)
  1350. node = ToolNode([state_tool, foo_tool])
  1351. tool_call1 = {"name": "state_tool", "args": {"x": 1}, "id": "1", "type": "tool_call"}
  1352. tool_call2 = {"name": "foo_tool", "args": {"x": 1}, "id": "2", "type": "tool_call"}
  1353. state = {
  1354. "messages": [AIMessage("", tool_calls=[tool_call1, tool_call2])],
  1355. "foo": "bar",
  1356. }
  1357. node.invoke(state)
  1358. ```
  1359. ```python
  1360. [
  1361. ToolMessage(content="not enough messages", name="state_tool", tool_call_id="1"),
  1362. ToolMessage(content="bar2", name="foo_tool", tool_call_id="2"),
  1363. ]
  1364. ```
  1365. !!! note
  1366. - `InjectedState` arguments are automatically excluded from tool schemas
  1367. presented to language models
  1368. - `ToolNode` handles the injection process during execution
  1369. - Tools can mix regular arguments (controlled by the model) with injected
  1370. arguments (controlled by the system)
  1371. - State injection occurs after the model generates tool calls but before
  1372. tool execution
  1373. """
  1374. def __init__(self, field: str | None = None) -> None:
  1375. """Initialize the `InjectedState` annotation."""
  1376. self.field = field
  1377. class InjectedStore(InjectedToolArg):
  1378. """Annotation for injecting persistent store into tool arguments.
  1379. This annotation enables tools to access LangGraph's persistent storage system
  1380. without exposing storage details to the language model. Tools annotated with
  1381. `InjectedStore` receive the store instance automatically during execution while
  1382. remaining invisible to the model's tool-calling interface.
  1383. The store provides persistent, cross-session data storage that tools can use
  1384. for maintaining context, user preferences, or any other data that needs to
  1385. persist beyond individual workflow executions.
  1386. !!! warning
  1387. `InjectedStore` annotation requires `langchain-core >= 0.3.8`
  1388. Example:
  1389. ```python
  1390. from typing_extensions import Annotated
  1391. from langgraph.store.memory import InMemoryStore
  1392. from langchain.tools import InjectedStore, ToolNode, tool
  1393. @tool
  1394. def save_preference(
  1395. key: str,
  1396. value: str,
  1397. store: Annotated[Any, InjectedStore()]
  1398. ) -> str:
  1399. \"\"\"Save user preference to persistent storage.\"\"\"
  1400. store.put(("preferences",), key, value)
  1401. return f"Saved {key} = {value}"
  1402. @tool
  1403. def get_preference(
  1404. key: str,
  1405. store: Annotated[Any, InjectedStore()]
  1406. ) -> str:
  1407. \"\"\"Retrieve user preference from persistent storage.\"\"\"
  1408. result = store.get(("preferences",), key)
  1409. return result.value if result else "Not found"
  1410. ```
  1411. Usage with `ToolNode` and graph compilation:
  1412. ```python
  1413. from langgraph.graph import StateGraph
  1414. from langgraph.store.memory import InMemoryStore
  1415. store = InMemoryStore()
  1416. tool_node = ToolNode([save_preference, get_preference])
  1417. graph = StateGraph(State)
  1418. graph.add_node("tools", tool_node)
  1419. compiled_graph = graph.compile(store=store) # Store is injected automatically
  1420. ```
  1421. Cross-session persistence:
  1422. ```python
  1423. # First session
  1424. result1 = graph.invoke({"messages": [HumanMessage("Save my favorite color as blue")]})
  1425. # Later session - data persists
  1426. result2 = graph.invoke({"messages": [HumanMessage("What's my favorite color?")]})
  1427. ```
  1428. !!! note
  1429. - `InjectedStore` arguments are automatically excluded from tool schemas
  1430. presented to language models
  1431. - The store instance is automatically injected by `ToolNode` during execution
  1432. - Tools can access namespaced storage using the store's get/put methods
  1433. - Store injection requires the graph to be compiled with a store instance
  1434. - Multiple tools can share the same store instance for data consistency
  1435. """
  1436. def _is_injection(
  1437. type_arg: Any,
  1438. injection_type: type[InjectedState | InjectedStore | ToolRuntime],
  1439. ) -> bool:
  1440. """Check if a type argument represents an injection annotation.
  1441. This utility function determines whether a type annotation indicates that
  1442. an argument should be injected with state or store data. It handles both
  1443. direct annotations and nested annotations within Union or Annotated types.
  1444. Args:
  1445. type_arg: The type argument to check for injection annotations.
  1446. injection_type: The injection type to look for (InjectedState or InjectedStore).
  1447. Returns:
  1448. True if the type argument contains the specified injection annotation.
  1449. """
  1450. if isinstance(type_arg, injection_type) or (
  1451. isinstance(type_arg, type) and issubclass(type_arg, injection_type)
  1452. ):
  1453. return True
  1454. origin_ = get_origin(type_arg)
  1455. if origin_ is Union or origin_ is Annotated:
  1456. return any(_is_injection(ta, injection_type) for ta in get_args(type_arg))
  1457. return False
  1458. def _get_injection_from_type(
  1459. type_: Any, injection_type: type[InjectedState | InjectedStore | ToolRuntime]
  1460. ) -> Any | None:
  1461. """Extract injection instance from a type annotation.
  1462. Args:
  1463. type_: The type annotation to check.
  1464. injection_type: The injection type to look for.
  1465. Returns:
  1466. The injection instance if found, True if injection marker found without instance, None otherwise.
  1467. """
  1468. type_args = get_args(type_)
  1469. matches = [arg for arg in type_args if _is_injection(arg, injection_type)]
  1470. if len(matches) > 1:
  1471. msg = (
  1472. f"A tool argument should not be annotated with {injection_type.__name__} "
  1473. f"more than once. Found: {matches}"
  1474. )
  1475. raise ValueError(msg)
  1476. if len(matches) == 1:
  1477. return matches[0]
  1478. elif _is_injection(type_, injection_type):
  1479. return True
  1480. return None
  1481. def _get_all_injected_args(tool: BaseTool) -> _InjectedArgs:
  1482. """Extract all injected arguments from tool in a single pass.
  1483. This function analyzes both the tool's input schema and function signature
  1484. to identify all arguments that should be injected (state, store, runtime).
  1485. Args:
  1486. tool: The tool to analyze for injection requirements.
  1487. Returns:
  1488. _InjectedArgs structure containing all detected injections.
  1489. """
  1490. # Get annotations from both schema and function signature
  1491. full_schema = tool.get_input_schema()
  1492. schema_annotations = get_all_basemodel_annotations(full_schema)
  1493. func = getattr(tool, "func", None) or getattr(tool, "coroutine", None)
  1494. func_annotations = get_type_hints(func, include_extras=True) if func else {}
  1495. # Combine both annotation sources, preferring schema annotations
  1496. # In the future, we might want to add more restrictions here...
  1497. all_annotations = {**func_annotations, **schema_annotations}
  1498. # Track injected args
  1499. state_args: dict[str, str | None] = {}
  1500. store_arg: str | None = None
  1501. runtime_arg: str | None = None
  1502. for name, type_ in all_annotations.items():
  1503. # Check for runtime (special case: parameter named "runtime")
  1504. if name == "runtime":
  1505. runtime_arg = name
  1506. # Check for InjectedState
  1507. if state_inj := _get_injection_from_type(type_, InjectedState):
  1508. if isinstance(state_inj, InjectedState) and state_inj.field:
  1509. state_args[name] = state_inj.field
  1510. else:
  1511. state_args[name] = None
  1512. # Check for InjectedStore
  1513. if _get_injection_from_type(type_, InjectedStore):
  1514. store_arg = name
  1515. # Check for ToolRuntime
  1516. if _get_injection_from_type(type_, ToolRuntime):
  1517. runtime_arg = name
  1518. return _InjectedArgs(
  1519. state=state_args,
  1520. store=store_arg,
  1521. runtime=runtime_arg,
  1522. )