- """Agent factory for creating agents with middleware support."""
- from __future__ import annotations
- import itertools
- from typing import (
- TYPE_CHECKING,
- Annotated,
- Any,
- cast,
- get_args,
- get_origin,
- get_type_hints,
- )
- from langchain_core.language_models.chat_models import BaseChatModel
- from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
- from langchain_core.tools import BaseTool
- from langgraph._internal._runnable import RunnableCallable
- from langgraph.constants import END, START
- from langgraph.graph.state import StateGraph
- from langgraph.prebuilt.tool_node import ToolCallWithContext, ToolNode
- from langgraph.runtime import Runtime # noqa: TC002
- from langgraph.types import Command, Send
- from langgraph.typing import ContextT # noqa: TC002
- from typing_extensions import NotRequired, Required, TypedDict
- from langchain.agents.middleware.types import (
- AgentMiddleware,
- AgentState,
- JumpTo,
- ModelRequest,
- ModelResponse,
- OmitFromSchema,
- ResponseT,
- StateT_co,
- _InputAgentState,
- _OutputAgentState,
- )
- from langchain.agents.structured_output import (
- AutoStrategy,
- MultipleStructuredOutputsError,
- OutputToolBinding,
- ProviderStrategy,
- ProviderStrategyBinding,
- ResponseFormat,
- StructuredOutputError,
- StructuredOutputValidationError,
- ToolStrategy,
- )
- from langchain.chat_models import init_chat_model
- if TYPE_CHECKING:
- from collections.abc import Awaitable, Callable, Sequence
- from langchain_core.runnables import Runnable
- from langgraph.cache.base import BaseCache
- from langgraph.graph.state import CompiledStateGraph
- from langgraph.store.base import BaseStore
- from langgraph.types import Checkpointer
- from langchain.agents.middleware.types import ToolCallRequest, ToolCallWrapper
STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."

FALLBACK_MODELS_WITH_STRUCTURED_OUTPUT = [
    # If model profile data is not available, these models are assumed to
    # support structured output.
    "grok",
    "gpt-5",
    "gpt-4.1",
    "gpt-4o",
    "gpt-oss",
    "o3-pro",
    "o3-mini",
]
def _normalize_to_model_response(result: ModelResponse | AIMessage) -> ModelResponse:
    """Normalize a middleware return value to a `ModelResponse`."""
    if isinstance(result, AIMessage):
        return ModelResponse(result=[result], structured_response=None)
    return result
def _chain_model_call_handlers(
    handlers: Sequence[
        Callable[
            [ModelRequest, Callable[[ModelRequest], ModelResponse]],
            ModelResponse | AIMessage,
        ]
    ],
) -> (
    Callable[
        [ModelRequest, Callable[[ModelRequest], ModelResponse]],
        ModelResponse,
    ]
    | None
):
    """Compose multiple `wrap_model_call` handlers into a single middleware stack.

    Handlers are composed so that the first handler in the list becomes the
    outermost layer. Each handler receives a callback that executes the inner
    layers.

    Args:
        handlers: List of handlers. The first handler wraps all others.

    Returns:
        Composed handler, or `None` if `handlers` is empty.

    Example:
        ```python
        # handlers=[auth, retry] means: auth wraps retry
        # Flow: auth calls retry, retry calls the base handler
        def auth(request, handler):
            try:
                return handler(request)
            except UnauthorizedError:
                refresh_token()
                return handler(request)


        def retry(request, handler):
            for attempt in range(3):
                try:
                    return handler(request)
                except Exception:
                    if attempt == 2:
                        raise


        handler = _chain_model_call_handlers([auth, retry])
        ```
    """
    if not handlers:
        return None

    if len(handlers) == 1:
        # Single handler - wrap it so the output is normalized to ModelResponse.
        single_handler = handlers[0]

        def normalized_single(
            request: ModelRequest,
            handler: Callable[[ModelRequest], ModelResponse],
        ) -> ModelResponse:
            result = single_handler(request, handler)
            return _normalize_to_model_response(result)

        return normalized_single

    def compose_two(
        outer: Callable[
            [ModelRequest, Callable[[ModelRequest], ModelResponse]],
            ModelResponse | AIMessage,
        ],
        inner: Callable[
            [ModelRequest, Callable[[ModelRequest], ModelResponse]],
            ModelResponse | AIMessage,
        ],
    ) -> Callable[
        [ModelRequest, Callable[[ModelRequest], ModelResponse]],
        ModelResponse,
    ]:
        """Compose two handlers where `outer` wraps `inner`."""

        def composed(
            request: ModelRequest,
            handler: Callable[[ModelRequest], ModelResponse],
        ) -> ModelResponse:
            # Create a wrapper that calls inner with the base handler and normalizes.
            def inner_handler(req: ModelRequest) -> ModelResponse:
                inner_result = inner(req, handler)
                return _normalize_to_model_response(inner_result)

            # Call outer with the wrapped inner as its handler and normalize.
            outer_result = outer(request, inner_handler)
            return _normalize_to_model_response(outer_result)

        return composed

    # Compose right-to-left: outer(inner(innermost(handler)))
    result = handlers[-1]
    for handler in reversed(handlers[:-1]):
        result = compose_two(handler, result)

    # Wrap to ensure the final return type is exactly ModelResponse.
    def final_normalized(
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelResponse:
        # `result` is typed as returning ModelResponse | AIMessage, but
        # compose_two already normalizes; this keeps the signature exact.
        final_result = result(request, handler)
        return _normalize_to_model_response(final_result)

    return final_normalized
def _chain_async_model_call_handlers(
    handlers: Sequence[
        Callable[
            [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
            Awaitable[ModelResponse | AIMessage],
        ]
    ],
) -> (
    Callable[
        [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
        Awaitable[ModelResponse],
    ]
    | None
):
    """Compose multiple async `wrap_model_call` handlers into a single middleware stack.

    Args:
        handlers: List of async handlers. The first handler wraps all others.

    Returns:
        Composed async handler, or `None` if `handlers` is empty.
    """
    if not handlers:
        return None

    if len(handlers) == 1:
        # Single handler - wrap it so the output is normalized to ModelResponse.
        single_handler = handlers[0]

        async def normalized_single(
            request: ModelRequest,
            handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
        ) -> ModelResponse:
            result = await single_handler(request, handler)
            return _normalize_to_model_response(result)

        return normalized_single

    def compose_two(
        outer: Callable[
            [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
            Awaitable[ModelResponse | AIMessage],
        ],
        inner: Callable[
            [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
            Awaitable[ModelResponse | AIMessage],
        ],
    ) -> Callable[
        [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
        Awaitable[ModelResponse],
    ]:
        """Compose two async handlers where `outer` wraps `inner`."""

        async def composed(
            request: ModelRequest,
            handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
        ) -> ModelResponse:
            # Create a wrapper that calls inner with the base handler and normalizes.
            async def inner_handler(req: ModelRequest) -> ModelResponse:
                inner_result = await inner(req, handler)
                return _normalize_to_model_response(inner_result)

            # Call outer with the wrapped inner as its handler and normalize.
            outer_result = await outer(request, inner_handler)
            return _normalize_to_model_response(outer_result)

        return composed

    # Compose right-to-left: outer(inner(innermost(handler)))
    result = handlers[-1]
    for handler in reversed(handlers[:-1]):
        result = compose_two(handler, result)

    # Wrap to ensure the final return type is exactly ModelResponse.
    async def final_normalized(
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> ModelResponse:
        # `result` is typed as returning ModelResponse | AIMessage, but
        # compose_two already normalizes; this keeps the signature exact.
        final_result = await result(request, handler)
        return _normalize_to_model_response(final_result)

    return final_normalized
def _resolve_schema(schemas: set[type], schema_name: str, omit_flag: str | None = None) -> type:
    """Resolve a schema by merging schemas, optionally respecting `OmitFromSchema` annotations.

    Args:
        schemas: Set of schema types to merge.
        schema_name: Name for the generated `TypedDict`.
        omit_flag: If specified, omit fields with this flag set (`'input'` or
            `'output'`).
    """
    all_annotations = {}
    for schema in schemas:
        hints = get_type_hints(schema, include_extras=True)
        for field_name, field_type in hints.items():
            should_omit = False
            if omit_flag:
                # Check for omission in the annotation metadata
                metadata = _extract_metadata(field_type)
                for meta in metadata:
                    if isinstance(meta, OmitFromSchema) and getattr(meta, omit_flag) is True:
                        should_omit = True
                        break
            if not should_omit:
                all_annotations[field_name] = field_type

    return TypedDict(schema_name, all_annotations)  # type: ignore[operator]
def _extract_metadata(type_: type) -> list:
    """Extract metadata from a field type, handling `Required`/`NotRequired` and `Annotated` wrappers.
    """
    # Handle Required[Annotated[...]] or NotRequired[Annotated[...]]
    if get_origin(type_) in (Required, NotRequired):
        inner_type = get_args(type_)[0]
        if get_origin(inner_type) is Annotated:
            return list(get_args(inner_type)[1:])
    # Handle direct Annotated[...]
    elif get_origin(type_) is Annotated:
        return list(get_args(type_)[1:])
    return []
def _get_can_jump_to(middleware: AgentMiddleware[Any, Any], hook_name: str) -> list[JumpTo]:
    """Get the `can_jump_to` list from either the sync or async hook method.

    Args:
        middleware: The middleware instance to inspect.
        hook_name: The name of the hook (`'before_model'` or `'after_model'`).

    Returns:
        List of jump destinations, or an empty list if not configured.
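
    Example:
        A minimal sketch: the lookup keys off the `__can_jump_to__` attribute of
        overridden hook methods (normally set by a decorator; set directly here
        purely for illustration):

        ```python
        class MyMiddleware(AgentMiddleware):
            def after_model(self, state, runtime):
                return None


        MyMiddleware.after_model.__can_jump_to__ = ["end"]
        _get_can_jump_to(MyMiddleware(), "after_model")  # -> ["end"]
        ```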
- """
- # Get the base class method for comparison
- base_sync_method = getattr(AgentMiddleware, hook_name, None)
- base_async_method = getattr(AgentMiddleware, f"a{hook_name}", None)
- # Try sync method first - only if it's overridden from base class
- sync_method = getattr(middleware.__class__, hook_name, None)
- if (
- sync_method
- and sync_method is not base_sync_method
- and hasattr(sync_method, "__can_jump_to__")
- ):
- return sync_method.__can_jump_to__
- # Try async method - only if it's overridden from base class
- async_method = getattr(middleware.__class__, f"a{hook_name}", None)
- if (
- async_method
- and async_method is not base_async_method
- and hasattr(async_method, "__can_jump_to__")
- ):
- return async_method.__can_jump_to__
- return []
def _supports_provider_strategy(model: str | BaseChatModel, tools: list | None = None) -> bool:
    """Check whether a model supports provider-specific structured output.

    Args:
        model: Model name string or `BaseChatModel` instance.
        tools: Optional list of tools provided to the agent. Needed because some
            models don't support structured output together with tool calling.

    Returns:
        `True` if the model supports provider-specific structured output, `False` otherwise.
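
    Example:
        Illustrative calls resolved via the fallback model list above:

        ```python
        _supports_provider_strategy("gpt-4o")  # -> True (fallback list match)
        _supports_provider_strategy("my-local-model")  # -> False (no profile, no match)
        ```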
- """
- model_name: str | None = None
- if isinstance(model, str):
- model_name = model
- elif isinstance(model, BaseChatModel):
- model_name = (
- getattr(model, "model_name", None)
- or getattr(model, "model", None)
- or getattr(model, "model_id", "")
- )
- model_profile = model.profile
- if (
- model_profile is not None
- and model_profile.get("structured_output")
- # We make an exception for Gemini models, which currently do not support
- # simultaneous tool use with structured output
- and not (tools and isinstance(model_name, str) and "gemini" in model_name.lower())
- ):
- return True
- return (
- any(part in model_name.lower() for part in FALLBACK_MODELS_WITH_STRUCTURED_OUTPUT)
- if model_name
- else False
- )
def _handle_structured_output_error(
    exception: Exception,
    response_format: ResponseFormat,
) -> tuple[bool, str]:
    """Handle a structured output error.

    Returns `(should_retry, retry_tool_message)` based on the strategy's
    `handle_errors` setting.
    """
    if not isinstance(response_format, ToolStrategy):
        return False, ""

    handle_errors = response_format.handle_errors
    if handle_errors is False:
        return False, ""
    if handle_errors is True:
        return True, STRUCTURED_OUTPUT_ERROR_TEMPLATE.format(error=str(exception))
    if isinstance(handle_errors, str):
        return True, handle_errors
    if isinstance(handle_errors, type) and issubclass(handle_errors, Exception):
        if isinstance(exception, handle_errors):
            return True, STRUCTURED_OUTPUT_ERROR_TEMPLATE.format(error=str(exception))
        return False, ""
    if isinstance(handle_errors, tuple):
        if any(isinstance(exception, exc_type) for exc_type in handle_errors):
            return True, STRUCTURED_OUTPUT_ERROR_TEMPLATE.format(error=str(exception))
        return False, ""
    if callable(handle_errors):
        # Type narrowing doesn't work appropriately with the callable check; can fix later.
        return True, handle_errors(exception)  # type: ignore[return-value,call-arg]
    return False, ""
def _chain_tool_call_wrappers(
    wrappers: Sequence[ToolCallWrapper],
) -> ToolCallWrapper | None:
    """Compose wrappers into a middleware stack (first = outermost).

    Args:
        wrappers: Wrappers in middleware order.

    Returns:
        Composed wrapper, or `None` if empty.

    Example:
        ```python
        wrapper = _chain_tool_call_wrappers([auth, cache, retry])
        # Request flows: auth -> cache -> retry -> tool
        # Response flows: tool -> retry -> cache -> auth
        ```
- """
- if not wrappers:
- return None
- if len(wrappers) == 1:
- return wrappers[0]
- def compose_two(outer: ToolCallWrapper, inner: ToolCallWrapper) -> ToolCallWrapper:
- """Compose two wrappers where outer wraps inner."""
- def composed(
- request: ToolCallRequest,
- execute: Callable[[ToolCallRequest], ToolMessage | Command],
- ) -> ToolMessage | Command:
- # Create a callable that invokes inner with the original execute
- def call_inner(req: ToolCallRequest) -> ToolMessage | Command:
- return inner(req, execute)
- # Outer can call call_inner multiple times
- return outer(request, call_inner)
- return composed
- # Chain all wrappers: first -> second -> ... -> last
- result = wrappers[-1]
- for wrapper in reversed(wrappers[:-1]):
- result = compose_two(wrapper, result)
- return result
def _chain_async_tool_call_wrappers(
    wrappers: Sequence[
        Callable[
            [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
            Awaitable[ToolMessage | Command],
        ]
    ],
) -> (
    Callable[
        [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
        Awaitable[ToolMessage | Command],
    ]
    | None
):
    """Compose async wrappers into a middleware stack (first = outermost).

    Args:
        wrappers: Async wrappers in middleware order.

    Returns:
        Composed async wrapper, or `None` if empty.
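
    Example:
        A minimal sketch of an async wrapper; the wrapper name and the logging
        are illustrative:

        ```python
        async def log_calls(request, execute):
            print("executing tool call")
            result = await execute(request)
            print("tool call finished")
            return result


        wrapper = _chain_async_tool_call_wrappers([log_calls])
        ```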
- """
- if not wrappers:
- return None
- if len(wrappers) == 1:
- return wrappers[0]
- def compose_two(
- outer: Callable[
- [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
- Awaitable[ToolMessage | Command],
- ],
- inner: Callable[
- [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
- Awaitable[ToolMessage | Command],
- ],
- ) -> Callable[
- [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
- Awaitable[ToolMessage | Command],
- ]:
- """Compose two async wrappers where outer wraps inner."""
- async def composed(
- request: ToolCallRequest,
- execute: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
- ) -> ToolMessage | Command:
- # Create an async callable that invokes inner with the original execute
- async def call_inner(req: ToolCallRequest) -> ToolMessage | Command:
- return await inner(req, execute)
- # Outer can call call_inner multiple times
- return await outer(request, call_inner)
- return composed
- # Chain all wrappers: first -> second -> ... -> last
- result = wrappers[-1]
- for wrapper in reversed(wrappers[:-1]):
- result = compose_two(wrapper, result)
- return result
def create_agent(  # noqa: PLR0915
    model: str | BaseChatModel,
    tools: Sequence[BaseTool | Callable | dict[str, Any]] | None = None,
    *,
    system_prompt: str | SystemMessage | None = None,
    middleware: Sequence[AgentMiddleware[StateT_co, ContextT]] = (),
    response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None,
    state_schema: type[AgentState[ResponseT]] | None = None,
    context_schema: type[ContextT] | None = None,
    checkpointer: Checkpointer | None = None,
    store: BaseStore | None = None,
    interrupt_before: list[str] | None = None,
    interrupt_after: list[str] | None = None,
    debug: bool = False,
    name: str | None = None,
    cache: BaseCache | None = None,
) -> CompiledStateGraph[
    AgentState[ResponseT], ContextT, _InputAgentState, _OutputAgentState[ResponseT]
]:
    """Create an agent graph that calls tools in a loop until a stopping condition is met.

    For more details on using `create_agent`,
    visit the [Agents](https://docs.langchain.com/oss/python/langchain/agents) docs.

    Args:
        model: The language model for the agent.
            Can be a string identifier (e.g., `"openai:gpt-4"`) or a chat model
            instance (e.g., [`ChatOpenAI`][langchain_openai.ChatOpenAI] or another
            [LangChain chat model](https://docs.langchain.com/oss/python/integrations/chat)).
            For a full list of supported model strings, see
            [`init_chat_model`][langchain.chat_models.init_chat_model(model_provider)].

            !!! tip ""
                See the [Models](https://docs.langchain.com/oss/python/langchain/models)
                docs for more information.
        tools: A list of tools, `dict`s, or `Callable`s.
            If `None` or an empty list, the agent will consist of a model node without a
            tool-calling loop.

            !!! tip ""
                See the [Tools](https://docs.langchain.com/oss/python/langchain/tools)
                docs for more information.
        system_prompt: An optional system prompt for the LLM.
            Can be a `str` (which will be converted to a `SystemMessage`) or a
            `SystemMessage` instance directly. The system message is added to the
            beginning of the message list when calling the model.
        middleware: A sequence of middleware instances to apply to the agent.
            Middleware can intercept and modify agent behavior at various stages.

            !!! tip ""
                See the [Middleware](https://docs.langchain.com/oss/python/langchain/middleware)
                docs for more information.
        response_format: An optional configuration for structured responses.
            Can be a `ToolStrategy`, `ProviderStrategy`, or a Pydantic model class.
            If provided, the agent will handle structured output during the
            conversation flow.
            Raw schemas will be wrapped in an appropriate strategy based on model
            capabilities.

            !!! tip ""
                See the [Structured output](https://docs.langchain.com/oss/python/langchain/structured-output)
                docs for more information.
        state_schema: An optional `TypedDict` schema that extends `AgentState`.
            When provided, this schema is used instead of `AgentState` as the base
            schema for merging with middleware state schemas. This allows users to
            add custom state fields without needing to create custom middleware.
            Generally, it's recommended to add `state_schema` extensions via middleware
            to keep relevant extensions scoped to corresponding hooks / tools.
        context_schema: An optional schema for runtime context.
        checkpointer: An optional checkpoint saver object.
            Used for persisting the state of the graph (e.g., as chat memory) for a
            single thread (e.g., a single conversation).
        store: An optional store object.
            Used for persisting data across multiple threads (e.g., multiple
            conversations / users).
        interrupt_before: An optional list of node names to interrupt before.
            Useful if you want to add a user confirmation or other interrupt
            before taking an action.
        interrupt_after: An optional list of node names to interrupt after.
            Useful if you want to return directly or run additional processing
            on an output.
        debug: Whether to enable verbose logging for graph execution.
            When enabled, prints detailed information about each node execution, state
            updates, and transitions during agent runtime. Useful for debugging
            middleware behavior and understanding agent execution flow.
        name: An optional name for the `CompiledStateGraph`.
            This name will be automatically used when adding the agent graph to
            another graph as a subgraph node - particularly useful for building
            multi-agent systems.
        cache: An optional `BaseCache` instance to enable caching of graph execution.

    Returns:
        A compiled `StateGraph` that can be used for chat interactions.

        The agent node calls the language model with the messages list (after applying
        the system prompt). If the resulting [`AIMessage`][langchain.messages.AIMessage]
        contains `tool_calls`, the graph will then call the tools. The tools node executes
        the tools and adds the responses to the messages list as
        [`ToolMessage`][langchain.messages.ToolMessage] objects. The agent node then calls
        the language model again. The process repeats until no more `tool_calls` are present
        in the response. The agent then returns the full list of messages.

    Example:
        ```python
        from langchain.agents import create_agent


        def check_weather(location: str) -> str:
            '''Return the weather forecast for the specified location.'''
            return f"It's always sunny in {location}"


        graph = create_agent(
            model="anthropic:claude-sonnet-4-5-20250929",
            tools=[check_weather],
            system_prompt="You are a helpful assistant",
        )
        inputs = {"messages": [{"role": "user", "content": "what is the weather in sf"}]}
        for chunk in graph.stream(inputs, stream_mode="updates"):
            print(chunk)
        ```
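
        Structured output, as a sketch: this assumes an OpenAI model, reuses the
        `check_weather` tool from above, and lets the strategy be auto-detected
        from model capabilities:

        ```python
        from pydantic import BaseModel


        class Weather(BaseModel):
            temperature: float
            conditions: str


        agent = create_agent(
            model="openai:gpt-4o",
            tools=[check_weather],
            response_format=Weather,
        )
        result = agent.invoke(
            {"messages": [{"role": "user", "content": "what is the weather in sf"}]}
        )
        result["structured_response"]  # -> Weather instance
        ```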
- """
- # init chat model
- if isinstance(model, str):
- model = init_chat_model(model)
- # Convert system_prompt to SystemMessage if needed
- system_message: SystemMessage | None = None
- if system_prompt is not None:
- if isinstance(system_prompt, SystemMessage):
- system_message = system_prompt
- else:
- system_message = SystemMessage(content=system_prompt)
- # Handle tools being None or empty
- if tools is None:
- tools = []
- # Convert response format and setup structured output tools
- # Raw schemas are wrapped in AutoStrategy to preserve auto-detection intent.
- # AutoStrategy is converted to ToolStrategy upfront to calculate tools during agent creation,
- # but may be replaced with ProviderStrategy later based on model capabilities.
- initial_response_format: ToolStrategy | ProviderStrategy | AutoStrategy | None
- if response_format is None:
- initial_response_format = None
- elif isinstance(response_format, (ToolStrategy, ProviderStrategy)):
- # Preserve explicitly requested strategies
- initial_response_format = response_format
- elif isinstance(response_format, AutoStrategy):
- # AutoStrategy provided - preserve it for later auto-detection
- initial_response_format = response_format
- else:
- # Raw schema - wrap in AutoStrategy to enable auto-detection
- initial_response_format = AutoStrategy(schema=response_format)
- # For AutoStrategy, convert to ToolStrategy to setup tools upfront
- # (may be replaced with ProviderStrategy later based on model)
- tool_strategy_for_setup: ToolStrategy | None = None
- if isinstance(initial_response_format, AutoStrategy):
- tool_strategy_for_setup = ToolStrategy(schema=initial_response_format.schema)
- elif isinstance(initial_response_format, ToolStrategy):
- tool_strategy_for_setup = initial_response_format
- structured_output_tools: dict[str, OutputToolBinding] = {}
- if tool_strategy_for_setup:
- for response_schema in tool_strategy_for_setup.schema_specs:
- structured_tool_info = OutputToolBinding.from_schema_spec(response_schema)
- structured_output_tools[structured_tool_info.tool.name] = structured_tool_info
- middleware_tools = [t for m in middleware for t in getattr(m, "tools", [])]
- # Collect middleware with wrap_tool_call or awrap_tool_call hooks
- # Include middleware with either implementation to ensure NotImplementedError is raised
- # when middleware doesn't support the execution path
- middleware_w_wrap_tool_call = [
- m
- for m in middleware
- if m.__class__.wrap_tool_call is not AgentMiddleware.wrap_tool_call
- or m.__class__.awrap_tool_call is not AgentMiddleware.awrap_tool_call
- ]
- # Chain all wrap_tool_call handlers into a single composed handler
- wrap_tool_call_wrapper = None
- if middleware_w_wrap_tool_call:
- wrappers = [m.wrap_tool_call for m in middleware_w_wrap_tool_call]
- wrap_tool_call_wrapper = _chain_tool_call_wrappers(wrappers)
- # Collect middleware with awrap_tool_call or wrap_tool_call hooks
- # Include middleware with either implementation to ensure NotImplementedError is raised
- # when middleware doesn't support the execution path
- middleware_w_awrap_tool_call = [
- m
- for m in middleware
- if m.__class__.awrap_tool_call is not AgentMiddleware.awrap_tool_call
- or m.__class__.wrap_tool_call is not AgentMiddleware.wrap_tool_call
- ]
- # Chain all awrap_tool_call handlers into a single composed async handler
- awrap_tool_call_wrapper = None
- if middleware_w_awrap_tool_call:
- async_wrappers = [m.awrap_tool_call for m in middleware_w_awrap_tool_call]
- awrap_tool_call_wrapper = _chain_async_tool_call_wrappers(async_wrappers)
    # Set up tools.
    tool_node: ToolNode | None = None

    # Extract built-in provider tools (dict format) and regular tools
    # (BaseTool instances / callables).
    built_in_tools = [t for t in tools if isinstance(t, dict)]
    regular_tools = [t for t in tools if not isinstance(t, dict)]

    # Tools that require client-side execution (must be in the ToolNode).
    available_tools = middleware_tools + regular_tools

    # Only create a ToolNode if we have client-side tools.
    tool_node = (
        ToolNode(
            tools=available_tools,
            wrap_tool_call=wrap_tool_call_wrapper,
            awrap_tool_call=awrap_tool_call_wrapper,
        )
        if available_tools
        else None
    )

    # Default tools for ModelRequest initialization.
    # Use converted BaseTool instances from the ToolNode (not raw callables).
    # Include built-ins and converted tools (can be changed dynamically by middleware).
    # Structured tools are NOT included - they're added dynamically based on response_format.
    if tool_node:
        default_tools = list(tool_node.tools_by_name.values()) + built_in_tools
    else:
        default_tools = list(built_in_tools)

    # Validate middleware uniqueness.
    assert len({m.name for m in middleware}) == len(middleware), (  # noqa: S101
        "Please remove duplicate middleware instances."
    )
    middleware_w_before_agent = [
        m
        for m in middleware
        if m.__class__.before_agent is not AgentMiddleware.before_agent
        or m.__class__.abefore_agent is not AgentMiddleware.abefore_agent
    ]
    middleware_w_before_model = [
        m
        for m in middleware
        if m.__class__.before_model is not AgentMiddleware.before_model
        or m.__class__.abefore_model is not AgentMiddleware.abefore_model
    ]
    middleware_w_after_model = [
        m
        for m in middleware
        if m.__class__.after_model is not AgentMiddleware.after_model
        or m.__class__.aafter_model is not AgentMiddleware.aafter_model
    ]
    middleware_w_after_agent = [
        m
        for m in middleware
        if m.__class__.after_agent is not AgentMiddleware.after_agent
        or m.__class__.aafter_agent is not AgentMiddleware.aafter_agent
    ]

    # Collect middleware with wrap_model_call or awrap_model_call hooks.
    # Include middleware with either implementation to ensure NotImplementedError
    # is raised when middleware doesn't support the execution path.
    middleware_w_wrap_model_call = [
        m
        for m in middleware
        if m.__class__.wrap_model_call is not AgentMiddleware.wrap_model_call
        or m.__class__.awrap_model_call is not AgentMiddleware.awrap_model_call
    ]

    # Collect middleware with awrap_model_call or wrap_model_call hooks.
    # Include middleware with either implementation to ensure NotImplementedError
    # is raised when middleware doesn't support the execution path.
    middleware_w_awrap_model_call = [
        m
        for m in middleware
        if m.__class__.awrap_model_call is not AgentMiddleware.awrap_model_call
        or m.__class__.wrap_model_call is not AgentMiddleware.wrap_model_call
    ]

    # Compose wrap_model_call handlers into a single middleware stack (sync).
    wrap_model_call_handler = None
    if middleware_w_wrap_model_call:
        sync_handlers = [m.wrap_model_call for m in middleware_w_wrap_model_call]
        wrap_model_call_handler = _chain_model_call_handlers(sync_handlers)

    # Compose awrap_model_call handlers into a single middleware stack (async).
    awrap_model_call_handler = None
    if middleware_w_awrap_model_call:
        async_handlers = [m.awrap_model_call for m in middleware_w_awrap_model_call]
        awrap_model_call_handler = _chain_async_model_call_handlers(async_handlers)

    state_schemas: set[type] = {m.state_schema for m in middleware}
    # Use the provided state_schema if available, otherwise the base AgentState.
    base_state = state_schema if state_schema is not None else AgentState
    state_schemas.add(base_state)

    resolved_state_schema = _resolve_schema(state_schemas, "StateSchema", None)
    input_schema = _resolve_schema(state_schemas, "InputSchema", "input")
    output_schema = _resolve_schema(state_schemas, "OutputSchema", "output")

    # Create the graph and add nodes.
    graph: StateGraph[
        AgentState[ResponseT], ContextT, _InputAgentState, _OutputAgentState[ResponseT]
    ] = StateGraph(
        state_schema=resolved_state_schema,
        input_schema=input_schema,
        output_schema=output_schema,
        context_schema=context_schema,
    )
    def _handle_model_output(
        output: AIMessage, effective_response_format: ResponseFormat | None
    ) -> dict[str, Any]:
        """Handle model output, including structured responses.

        Args:
            output: The AI message output from the model.
            effective_response_format: The actual strategy used
                (may differ from the initial one if auto-detected).
        """
        # Handle structured output with the provider strategy
        if isinstance(effective_response_format, ProviderStrategy):
            if not output.tool_calls:
                provider_strategy_binding = ProviderStrategyBinding.from_schema_spec(
                    effective_response_format.schema_spec
                )
                try:
                    structured_response = provider_strategy_binding.parse(output)
                except Exception as exc:  # noqa: BLE001
                    schema_name = getattr(
                        effective_response_format.schema_spec.schema, "__name__", "response_format"
                    )
                    validation_error = StructuredOutputValidationError(schema_name, exc, output)
                    raise validation_error
                else:
                    return {"messages": [output], "structured_response": structured_response}
            return {"messages": [output]}

        # Handle structured output with the tool strategy
        if (
            isinstance(effective_response_format, ToolStrategy)
            and isinstance(output, AIMessage)
            and output.tool_calls
        ):
            structured_tool_calls = [
                tc for tc in output.tool_calls if tc["name"] in structured_output_tools
            ]
            if structured_tool_calls:
                exception: StructuredOutputError | None = None
                if len(structured_tool_calls) > 1:
                    # Handle the multiple-structured-outputs error
                    tool_names = [tc["name"] for tc in structured_tool_calls]
                    exception = MultipleStructuredOutputsError(tool_names, output)
                    should_retry, error_message = _handle_structured_output_error(
                        exception, effective_response_format
                    )
                    if not should_retry:
                        raise exception

                    # Add error messages and retry
                    tool_messages = [
                        ToolMessage(
                            content=error_message,
                            tool_call_id=tc["id"],
                            name=tc["name"],
                        )
                        for tc in structured_tool_calls
                    ]
                    return {"messages": [output, *tool_messages]}

                # Handle a single structured output
                tool_call = structured_tool_calls[0]
                try:
                    structured_tool_binding = structured_output_tools[tool_call["name"]]
                    structured_response = structured_tool_binding.parse(tool_call["args"])
                    tool_message_content = (
                        effective_response_format.tool_message_content
                        if effective_response_format.tool_message_content
                        else f"Returning structured response: {structured_response}"
                    )
                    return {
                        "messages": [
                            output,
                            ToolMessage(
                                content=tool_message_content,
                                tool_call_id=tool_call["id"],
                                name=tool_call["name"],
                            ),
                        ],
                        "structured_response": structured_response,
                    }
                except Exception as exc:  # noqa: BLE001
                    exception = StructuredOutputValidationError(tool_call["name"], exc, output)
                    should_retry, error_message = _handle_structured_output_error(
                        exception, effective_response_format
                    )
                    if not should_retry:
                        raise exception

                    return {
                        "messages": [
                            output,
                            ToolMessage(
                                content=error_message,
                                tool_call_id=tool_call["id"],
                                name=tool_call["name"],
                            ),
                        ],
                    }

        return {"messages": [output]}
    def _get_bound_model(request: ModelRequest) -> tuple[Runnable, ResponseFormat | None]:
        """Get the model with appropriate tool bindings.

        Performs auto-detection of the strategy if needed, based on model capabilities.

        Args:
            request: The model request containing the model, tools, and response format.

        Returns:
            Tuple of `(bound_model, effective_response_format)` where
            `effective_response_format` is the actual strategy used (may differ
            from the initial one if auto-detected).
        """
        # Validate ONLY client-side tools that need to exist in the tool_node.
        # Build a map of available client-side tools from the ToolNode
        # (which has already converted callables).
        available_tools_by_name = {}
        if tool_node:
            available_tools_by_name = tool_node.tools_by_name.copy()

        # Check if any requested tools are unknown CLIENT-SIDE tools.
        unknown_tool_names = []
        for t in request.tools:
            # Only validate BaseTool instances (skip built-in dict tools).
            if isinstance(t, dict):
                continue
            if isinstance(t, BaseTool) and t.name not in available_tools_by_name:
                unknown_tool_names.append(t.name)

        if unknown_tool_names:
            available_tool_names = sorted(available_tools_by_name.keys())
            msg = (
                f"Middleware returned unknown tool names: {unknown_tool_names}\n\n"
                f"Available client-side tools: {available_tool_names}\n\n"
                "To fix this issue:\n"
                "1. Ensure the tools are passed to create_agent() via "
                "the 'tools' parameter\n"
                "2. If using custom middleware with tools, ensure "
                "they're registered via the middleware.tools attribute\n"
                "3. Verify that tool names in ModelRequest.tools match "
                "the actual tool.name values\n"
                "Note: Built-in provider tools (dict format) can be added dynamically."
            )
            raise ValueError(msg)

        # Determine the effective response format (auto-detect if needed).
        effective_response_format: ResponseFormat | None
        if isinstance(request.response_format, AutoStrategy):
            # User provided a raw schema via AutoStrategy - auto-detect the best
            # strategy based on the model.
            if _supports_provider_strategy(request.model, tools=request.tools):
                # Model supports the provider strategy - use it
                effective_response_format = ProviderStrategy(schema=request.response_format.schema)
            else:
                # Model doesn't support the provider strategy - use ToolStrategy
                effective_response_format = ToolStrategy(schema=request.response_format.schema)
        else:
            # User explicitly specified a strategy - preserve it
            effective_response_format = request.response_format

        # Build the final tools list, including structured output tools.
        # request.tools now only contains BaseTool instances (converted from
        # callables) and dicts (built-ins).
        final_tools = list(request.tools)
        if isinstance(effective_response_format, ToolStrategy):
            # Add structured output tools to the final tools list
            structured_tools = [info.tool for info in structured_output_tools.values()]
            final_tools.extend(structured_tools)

        # Bind the model based on the effective response format.
        if isinstance(effective_response_format, ProviderStrategy):
            # (Backward compatibility) Use OpenAI-format structured output
            kwargs = effective_response_format.to_model_kwargs()
            return (
                request.model.bind_tools(
                    final_tools, strict=True, **kwargs, **request.model_settings
                ),
                effective_response_format,
            )

        if isinstance(effective_response_format, ToolStrategy):
            # The current implementation requires that tools used for structured
            # output be declared upfront, when creating the agent, as part of the
            # response format. Middleware is allowed to change the response format
            # to a subset of the original structured tools when using ToolStrategy,
            # but not to add new structured tools that weren't declared upfront.
            # Compute the output binding.
            for tc in effective_response_format.schema_specs:
                if tc.name not in structured_output_tools:
                    msg = (
                        f"ToolStrategy specifies tool '{tc.name}' "
                        "which wasn't declared in the original "
                        "response format when creating the agent."
                    )
                    raise ValueError(msg)

            # Force tool use if we have structured output tools
            tool_choice = "any" if structured_output_tools else request.tool_choice
            return (
                request.model.bind_tools(
                    final_tools, tool_choice=tool_choice, **request.model_settings
                ),
                effective_response_format,
            )

        # No structured output - standard model binding
        if final_tools:
            return (
                request.model.bind_tools(
                    final_tools, tool_choice=request.tool_choice, **request.model_settings
                ),
                None,
            )

        return request.model.bind(**request.model_settings), None
    def _execute_model_sync(request: ModelRequest) -> ModelResponse:
        """Execute the model and return the response.

        This is the core model execution logic wrapped by `wrap_model_call` handlers.
        Raises any exceptions that occur during model invocation.
        """
        # Get the bound model (with auto-detection if needed)
        model_, effective_response_format = _get_bound_model(request)
        messages = request.messages
        if request.system_message:
            messages = [request.system_message, *messages]

        output = model_.invoke(messages)
        if name:
            output.name = name

        # Handle the model output to get messages and structured_response
        handled_output = _handle_model_output(output, effective_response_format)
        messages_list = handled_output["messages"]
        structured_response = handled_output.get("structured_response")
        return ModelResponse(
            result=messages_list,
            structured_response=structured_response,
        )

    def model_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
        """Sync model request handler with sequential middleware processing."""
        request = ModelRequest(
            model=model,
            tools=default_tools,
            system_message=system_message,
            response_format=initial_response_format,
            messages=state["messages"],
            tool_choice=None,
            state=state,
            runtime=runtime,
        )

        if wrap_model_call_handler is None:
            # No handlers - execute directly
            response = _execute_model_sync(request)
        else:
            # Call the composed handler with the base handler
            response = wrap_model_call_handler(request, _execute_model_sync)

        # Extract state updates from the ModelResponse
        state_updates = {"messages": response.result}
        if response.structured_response is not None:
            state_updates["structured_response"] = response.structured_response
        return state_updates
    async def _execute_model_async(request: ModelRequest) -> ModelResponse:
        """Execute the model asynchronously and return the response.

        This is the core async model execution logic wrapped by `awrap_model_call`
        handlers.
        Raises any exceptions that occur during model invocation.
        """
        # Get the bound model (with auto-detection if needed)
        model_, effective_response_format = _get_bound_model(request)
        messages = request.messages
        if request.system_message:
            messages = [request.system_message, *messages]

        output = await model_.ainvoke(messages)
        if name:
            output.name = name

        # Handle the model output to get messages and structured_response
        handled_output = _handle_model_output(output, effective_response_format)
        messages_list = handled_output["messages"]
        structured_response = handled_output.get("structured_response")
        return ModelResponse(
            result=messages_list,
            structured_response=structured_response,
        )

    async def amodel_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
        """Async model request handler with sequential middleware processing."""
        request = ModelRequest(
            model=model,
            tools=default_tools,
            system_message=system_message,
            response_format=initial_response_format,
            messages=state["messages"],
            tool_choice=None,
            state=state,
            runtime=runtime,
        )

        if awrap_model_call_handler is None:
            # No async handlers - execute directly
            response = await _execute_model_async(request)
        else:
            # Call the composed async handler with the base handler
            response = await awrap_model_call_handler(request, _execute_model_async)

        # Extract state updates from the ModelResponse
        state_updates = {"messages": response.result}
        if response.structured_response is not None:
            state_updates["structured_response"] = response.structured_response
        return state_updates
    # RunnableCallable dispatches to the sync or async implementation based on
    # the invocation path.
    graph.add_node("model", RunnableCallable(model_node, amodel_node, trace=False))

    # Only add the tools node if we have tools.
    if tool_node is not None:
        graph.add_node("tools", tool_node)

    # Add middleware nodes.
    for m in middleware:
        if (
            m.__class__.before_agent is not AgentMiddleware.before_agent
            or m.__class__.abefore_agent is not AgentMiddleware.abefore_agent
        ):
            # Use RunnableCallable to support both sync and async.
            # Pass None for sync if not overridden to avoid signature conflicts.
            sync_before_agent = (
                m.before_agent
                if m.__class__.before_agent is not AgentMiddleware.before_agent
                else None
            )
            async_before_agent = (
                m.abefore_agent
                if m.__class__.abefore_agent is not AgentMiddleware.abefore_agent
                else None
            )
            before_agent_node = RunnableCallable(sync_before_agent, async_before_agent, trace=False)
            graph.add_node(
                f"{m.name}.before_agent", before_agent_node, input_schema=resolved_state_schema
            )

        if (
            m.__class__.before_model is not AgentMiddleware.before_model
            or m.__class__.abefore_model is not AgentMiddleware.abefore_model
        ):
            # Use RunnableCallable to support both sync and async.
            # Pass None for sync if not overridden to avoid signature conflicts.
            sync_before = (
                m.before_model
                if m.__class__.before_model is not AgentMiddleware.before_model
                else None
            )
            async_before = (
                m.abefore_model
                if m.__class__.abefore_model is not AgentMiddleware.abefore_model
                else None
            )
            before_node = RunnableCallable(sync_before, async_before, trace=False)
            graph.add_node(
                f"{m.name}.before_model", before_node, input_schema=resolved_state_schema
            )

        if (
            m.__class__.after_model is not AgentMiddleware.after_model
            or m.__class__.aafter_model is not AgentMiddleware.aafter_model
        ):
            # Use RunnableCallable to support both sync and async.
            # Pass None for sync if not overridden to avoid signature conflicts.
            sync_after = (
                m.after_model
                if m.__class__.after_model is not AgentMiddleware.after_model
                else None
            )
            async_after = (
                m.aafter_model
                if m.__class__.aafter_model is not AgentMiddleware.aafter_model
                else None
            )
            after_node = RunnableCallable(sync_after, async_after, trace=False)
            graph.add_node(f"{m.name}.after_model", after_node, input_schema=resolved_state_schema)

        if (
            m.__class__.after_agent is not AgentMiddleware.after_agent
            or m.__class__.aafter_agent is not AgentMiddleware.aafter_agent
        ):
            # Use RunnableCallable to support both sync and async.
            # Pass None for sync if not overridden to avoid signature conflicts.
            sync_after_agent = (
                m.after_agent
                if m.__class__.after_agent is not AgentMiddleware.after_agent
                else None
            )
            async_after_agent = (
                m.aafter_agent
                if m.__class__.aafter_agent is not AgentMiddleware.aafter_agent
                else None
            )
            after_agent_node = RunnableCallable(sync_after_agent, async_after_agent, trace=False)
            graph.add_node(
                f"{m.name}.after_agent", after_agent_node, input_schema=resolved_state_schema
            )
    # Determine the entry node (runs once at start): before_agent -> before_model -> model.
    if middleware_w_before_agent:
        entry_node = f"{middleware_w_before_agent[0].name}.before_agent"
    elif middleware_w_before_model:
        entry_node = f"{middleware_w_before_model[0].name}.before_model"
    else:
        entry_node = "model"

    # Determine the loop entry node (beginning of the agent loop, excludes before_agent).
    # This is where tools will loop back to for the next iteration.
    if middleware_w_before_model:
        loop_entry_node = f"{middleware_w_before_model[0].name}.before_model"
    else:
        loop_entry_node = "model"

    # Determine the loop exit node (end of each iteration, can run multiple times).
    # This is after_model or model, but NOT after_agent.
    if middleware_w_after_model:
        loop_exit_node = f"{middleware_w_after_model[0].name}.after_model"
    else:
        loop_exit_node = "model"

    # Determine the exit node (runs once at the end): after_agent or END.
    if middleware_w_after_agent:
        exit_node = f"{middleware_w_after_agent[-1].name}.after_agent"
    else:
        exit_node = END

    graph.add_edge(START, entry_node)
    # Add conditional edges only if tools exist.
    if tool_node is not None:
        # Only include exit_node in the destinations if any tool has
        # return_direct=True, or if there are structured output tools.
        tools_to_model_destinations = [loop_entry_node]
        if (
            any(tool.return_direct for tool in tool_node.tools_by_name.values())
            or structured_output_tools
        ):
            tools_to_model_destinations.append(exit_node)
        graph.add_conditional_edges(
            "tools",
            RunnableCallable(
                _make_tools_to_model_edge(
                    tool_node=tool_node,
                    model_destination=loop_entry_node,
                    structured_output_tools=structured_output_tools,
                    end_destination=exit_node,
                ),
                trace=False,
            ),
            tools_to_model_destinations,
        )

        # The base destinations are tools and exit_node.
        # We add the loop entry node to the edge destinations if:
        # - there are after_model hook(s) -- allows jump_to back to the model after
        #   artificially injected tool messages (e.g., human-in-the-loop)
        # - there is a response format -- allows jumping back to the model to
        #   regenerate structured output tool calls
        model_to_tools_destinations = ["tools", exit_node]
        if response_format or loop_exit_node != "model":
            model_to_tools_destinations.append(loop_entry_node)
        graph.add_conditional_edges(
            loop_exit_node,
            RunnableCallable(
                _make_model_to_tools_edge(
                    model_destination=loop_entry_node,
                    structured_output_tools=structured_output_tools,
                    end_destination=exit_node,
                ),
                trace=False,
            ),
            model_to_tools_destinations,
        )
    elif len(structured_output_tools) > 0:
        graph.add_conditional_edges(
            loop_exit_node,
            RunnableCallable(
                _make_model_to_model_edge(
                    model_destination=loop_entry_node,
                    end_destination=exit_node,
                ),
                trace=False,
            ),
            [loop_entry_node, exit_node],
        )
    elif loop_exit_node == "model":
        # No tools and no after_model - go directly to exit_node.
        graph.add_edge(loop_exit_node, exit_node)
    else:
        # No tools, but we have after_model - connect after_model to exit_node.
        _add_middleware_edge(
            graph,
            name=f"{middleware_w_after_model[0].name}.after_model",
            default_destination=exit_node,
            model_destination=loop_entry_node,
            end_destination=exit_node,
            can_jump_to=_get_can_jump_to(middleware_w_after_model[0], "after_model"),
        )
    # Add before_agent middleware edges.
    if middleware_w_before_agent:
        for m1, m2 in itertools.pairwise(middleware_w_before_agent):
            _add_middleware_edge(
                graph,
                name=f"{m1.name}.before_agent",
                default_destination=f"{m2.name}.before_agent",
                model_destination=loop_entry_node,
                end_destination=exit_node,
                can_jump_to=_get_can_jump_to(m1, "before_agent"),
            )

        # Connect the last before_agent to the loop entry node (before_model or model).
        _add_middleware_edge(
            graph,
            name=f"{middleware_w_before_agent[-1].name}.before_agent",
            default_destination=loop_entry_node,
            model_destination=loop_entry_node,
            end_destination=exit_node,
            can_jump_to=_get_can_jump_to(middleware_w_before_agent[-1], "before_agent"),
        )

    # Add before_model middleware edges.
    if middleware_w_before_model:
        for m1, m2 in itertools.pairwise(middleware_w_before_model):
            _add_middleware_edge(
                graph,
                name=f"{m1.name}.before_model",
                default_destination=f"{m2.name}.before_model",
                model_destination=loop_entry_node,
                end_destination=exit_node,
                can_jump_to=_get_can_jump_to(m1, "before_model"),
            )

        # Go directly to the model after the last before_model.
        _add_middleware_edge(
            graph,
            name=f"{middleware_w_before_model[-1].name}.before_model",
            default_destination="model",
            model_destination=loop_entry_node,
            end_destination=exit_node,
            can_jump_to=_get_can_jump_to(middleware_w_before_model[-1], "before_model"),
        )

    # Add after_model middleware edges.
    if middleware_w_after_model:
        graph.add_edge("model", f"{middleware_w_after_model[-1].name}.after_model")
        for idx in range(len(middleware_w_after_model) - 1, 0, -1):
            m1 = middleware_w_after_model[idx]
            m2 = middleware_w_after_model[idx - 1]
            _add_middleware_edge(
                graph,
                name=f"{m1.name}.after_model",
                default_destination=f"{m2.name}.after_model",
                model_destination=loop_entry_node,
                end_destination=exit_node,
                can_jump_to=_get_can_jump_to(m1, "after_model"),
            )
        # Note: the connection from after_model to after_agent/END is handled
        # above, in the conditional-edges section.

    # Add after_agent middleware edges.
    if middleware_w_after_agent:
        # Chain after_agent middleware (runs once at the very end, before END).
        for idx in range(len(middleware_w_after_agent) - 1, 0, -1):
            m1 = middleware_w_after_agent[idx]
            m2 = middleware_w_after_agent[idx - 1]
            _add_middleware_edge(
                graph,
                name=f"{m1.name}.after_agent",
                default_destination=f"{m2.name}.after_agent",
                model_destination=loop_entry_node,
                end_destination=exit_node,
                can_jump_to=_get_can_jump_to(m1, "after_agent"),
            )

        # Connect the last after_agent to END.
        _add_middleware_edge(
            graph,
            name=f"{middleware_w_after_agent[0].name}.after_agent",
            default_destination=END,
            model_destination=loop_entry_node,
            end_destination=exit_node,
            can_jump_to=_get_can_jump_to(middleware_w_after_agent[0], "after_agent"),
        )
    return graph.compile(
        checkpointer=checkpointer,
        store=store,
        interrupt_before=interrupt_before,
        interrupt_after=interrupt_after,
        debug=debug,
        name=name,
        cache=cache,
    ).with_config({"recursion_limit": 10_000})
def _resolve_jump(
    jump_to: JumpTo | None,
    *,
    model_destination: str,
    end_destination: str,
) -> str | None:
    """Map a `jump_to` directive to a concrete node name, or `None` if unset."""
    if jump_to == "model":
        return model_destination
    if jump_to == "end":
        return end_destination
    if jump_to == "tools":
        return "tools"
    return None
def _fetch_last_ai_and_tool_messages(
    messages: list[AnyMessage],
) -> tuple[AIMessage, list[ToolMessage]]:
    """Return the most recent `AIMessage` and the `ToolMessage`s that follow it.

    Assumes at least one `AIMessage` is present in `messages`.
    """
    last_ai_index: int
    last_ai_message: AIMessage

    for i in range(len(messages) - 1, -1, -1):
        if isinstance(messages[i], AIMessage):
            last_ai_index = i
            last_ai_message = cast("AIMessage", messages[i])
            break

    tool_messages = [m for m in messages[last_ai_index + 1 :] if isinstance(m, ToolMessage)]
    return last_ai_message, tool_messages
def _make_model_to_tools_edge(
    *,
    model_destination: str,
    structured_output_tools: dict[str, OutputToolBinding],
    end_destination: str,
) -> Callable[[dict[str, Any]], str | list[Send] | None]:
    def model_to_tools(
        state: dict[str, Any],
    ) -> str | list[Send] | None:
        # 1. If there's an explicit jump_to in the state, use it.
        if jump_to := state.get("jump_to"):
            return _resolve_jump(
                jump_to,
                model_destination=model_destination,
                end_destination=end_destination,
            )

        last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
        tool_message_ids = [m.tool_call_id for m in tool_messages]

        # 2. If the model hasn't called any tools, exit the loop.
        # This is the classic exit condition for an agent loop.
        if len(last_ai_message.tool_calls) == 0:
            return end_destination

        pending_tool_calls = [
            c
            for c in last_ai_message.tool_calls
            if c["id"] not in tool_message_ids and c["name"] not in structured_output_tools
        ]

        # 3. If there are pending tool calls, jump to the tools node.
        if pending_tool_calls:
            return [
                Send(
                    "tools",
                    ToolCallWithContext(
                        __type="tool_call_with_context",
                        tool_call=tool_call,
                        state=state,
                    ),
                )
                for tool_call in pending_tool_calls
            ]

        # 4. If there is a structured response, exit the loop.
        if "structured_response" in state:
            return end_destination

        # 5. The AIMessage has tool calls, but none are pending, which suggests
        # the injection of artificial tool messages. Jump back to the model node.
        return model_destination

    return model_to_tools
def _make_model_to_model_edge(
    *,
    model_destination: str,
    end_destination: str,
) -> Callable[[dict[str, Any]], str | list[Send] | None]:
    def model_to_model(
        state: dict[str, Any],
    ) -> str | list[Send] | None:
        # 1. Priority: check for an explicit jump_to directive from middleware.
        if jump_to := state.get("jump_to"):
            return _resolve_jump(
                jump_to,
                model_destination=model_destination,
                end_destination=end_destination,
            )

        # 2. Exit condition: a structured response was generated.
        if "structured_response" in state:
            return end_destination

        # 3. Default: continue the loop. There may have been an issue with
        # structured output generation, so we need to retry.
        return model_destination

    return model_to_model
def _make_tools_to_model_edge(
    *,
    tool_node: ToolNode,
    model_destination: str,
    structured_output_tools: dict[str, OutputToolBinding],
    end_destination: str,
) -> Callable[[dict[str, Any]], str | None]:
    def tools_to_model(state: dict[str, Any]) -> str | None:
        last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])

        # 1. Exit condition: all executed tools have return_direct=True.
        # Filter to only client-side tools (provider tools are not in the tool_node).
        client_side_tool_calls = [
            c for c in last_ai_message.tool_calls if c["name"] in tool_node.tools_by_name
        ]
        if client_side_tool_calls and all(
            tool_node.tools_by_name[c["name"]].return_direct for c in client_side_tool_calls
        ):
            return end_destination

        # 2. Exit condition: a structured output tool was executed.
        if any(t.name in structured_output_tools for t in tool_messages):
            return end_destination

        # 3. Default: continue the loop.
        # Tool execution completed successfully, so route back to the model
        # so it can process the tool results and decide the next action.
        return model_destination

    return tools_to_model
def _add_middleware_edge(
    graph: StateGraph[
        AgentState[ResponseT], ContextT, _InputAgentState, _OutputAgentState[ResponseT]
    ],
    *,
    name: str,
    default_destination: str,
    model_destination: str,
    end_destination: str,
    can_jump_to: list[JumpTo] | None,
) -> None:
    """Add an edge to the graph for a middleware node.

    Args:
        graph: The graph to add the edge to.
        name: The name of the middleware node.
        default_destination: The default destination for the edge.
        model_destination: The destination for a jump to the model.
        end_destination: The destination for a jump to the end.
        can_jump_to: The conditionally jumpable destinations for the edge.
    """
    if can_jump_to:

        def jump_edge(state: dict[str, Any]) -> str:
            return (
                _resolve_jump(
                    state.get("jump_to"),
                    model_destination=model_destination,
                    end_destination=end_destination,
                )
                or default_destination
            )

        destinations = [default_destination]
        if "end" in can_jump_to:
            destinations.append(end_destination)
        if "tools" in can_jump_to:
            destinations.append("tools")
        if "model" in can_jump_to and name != model_destination:
            destinations.append(model_destination)

        graph.add_conditional_edges(name, RunnableCallable(jump_edge, trace=False), destinations)
    else:
        graph.add_edge(name, default_destination)
__all__ = [
    "create_agent",
]