# factory.py
  1. """Agent factory for creating agents with middleware support."""
  2. from __future__ import annotations
  3. import itertools
  4. from typing import (
  5. TYPE_CHECKING,
  6. Annotated,
  7. Any,
  8. cast,
  9. get_args,
  10. get_origin,
  11. get_type_hints,
  12. )
  13. from langchain_core.language_models.chat_models import BaseChatModel
  14. from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
  15. from langchain_core.tools import BaseTool
  16. from langgraph._internal._runnable import RunnableCallable
  17. from langgraph.constants import END, START
  18. from langgraph.graph.state import StateGraph
  19. from langgraph.prebuilt.tool_node import ToolCallWithContext, ToolNode
  20. from langgraph.runtime import Runtime # noqa: TC002
  21. from langgraph.types import Command, Send
  22. from langgraph.typing import ContextT # noqa: TC002
  23. from typing_extensions import NotRequired, Required, TypedDict
  24. from langchain.agents.middleware.types import (
  25. AgentMiddleware,
  26. AgentState,
  27. JumpTo,
  28. ModelRequest,
  29. ModelResponse,
  30. OmitFromSchema,
  31. ResponseT,
  32. StateT_co,
  33. _InputAgentState,
  34. _OutputAgentState,
  35. )
  36. from langchain.agents.structured_output import (
  37. AutoStrategy,
  38. MultipleStructuredOutputsError,
  39. OutputToolBinding,
  40. ProviderStrategy,
  41. ProviderStrategyBinding,
  42. ResponseFormat,
  43. StructuredOutputError,
  44. StructuredOutputValidationError,
  45. ToolStrategy,
  46. )
  47. from langchain.chat_models import init_chat_model
  48. if TYPE_CHECKING:
  49. from collections.abc import Awaitable, Callable, Sequence
  50. from langchain_core.runnables import Runnable
  51. from langgraph.cache.base import BaseCache
  52. from langgraph.graph.state import CompiledStateGraph
  53. from langgraph.store.base import BaseStore
  54. from langgraph.types import Checkpointer
  55. from langchain.agents.middleware.types import ToolCallRequest, ToolCallWrapper
# Retry message template used when structured-output parsing fails;
# `{error}` is filled with the exception text (see _handle_structured_output_error).
STRUCTURED_OUTPUT_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes."
FALLBACK_MODELS_WITH_STRUCTURED_OUTPUT = [
    # if model profile data are not available, these models are assumed to support
    # structured output
    "grok",
    "gpt-5",
    "gpt-4.1",
    "gpt-4o",
    "gpt-oss",
    "o3-pro",
    "o3-mini",
]
  68. def _normalize_to_model_response(result: ModelResponse | AIMessage) -> ModelResponse:
  69. """Normalize middleware return value to ModelResponse."""
  70. if isinstance(result, AIMessage):
  71. return ModelResponse(result=[result], structured_response=None)
  72. return result
  73. def _chain_model_call_handlers(
  74. handlers: Sequence[
  75. Callable[
  76. [ModelRequest, Callable[[ModelRequest], ModelResponse]],
  77. ModelResponse | AIMessage,
  78. ]
  79. ],
  80. ) -> (
  81. Callable[
  82. [ModelRequest, Callable[[ModelRequest], ModelResponse]],
  83. ModelResponse,
  84. ]
  85. | None
  86. ):
  87. """Compose multiple wrap_model_call handlers into single middleware stack.
  88. Composes handlers so first in list becomes outermost layer. Each handler
  89. receives a handler callback to execute inner layers.
  90. Args:
  91. handlers: List of handlers. First handler wraps all others.
  92. Returns:
  93. Composed handler, or `None` if handlers empty.
  94. Example:
  95. ```python
  96. # handlers=[auth, retry] means: auth wraps retry
  97. # Flow: auth calls retry, retry calls base handler
  98. def auth(req, state, runtime, handler):
  99. try:
  100. return handler(req)
  101. except UnauthorizedError:
  102. refresh_token()
  103. return handler(req)
  104. def retry(req, state, runtime, handler):
  105. for attempt in range(3):
  106. try:
  107. return handler(req)
  108. except Exception:
  109. if attempt == 2:
  110. raise
  111. handler = _chain_model_call_handlers([auth, retry])
  112. ```
  113. """
  114. if not handlers:
  115. return None
  116. if len(handlers) == 1:
  117. # Single handler - wrap to normalize output
  118. single_handler = handlers[0]
  119. def normalized_single(
  120. request: ModelRequest,
  121. handler: Callable[[ModelRequest], ModelResponse],
  122. ) -> ModelResponse:
  123. result = single_handler(request, handler)
  124. return _normalize_to_model_response(result)
  125. return normalized_single
  126. def compose_two(
  127. outer: Callable[
  128. [ModelRequest, Callable[[ModelRequest], ModelResponse]],
  129. ModelResponse | AIMessage,
  130. ],
  131. inner: Callable[
  132. [ModelRequest, Callable[[ModelRequest], ModelResponse]],
  133. ModelResponse | AIMessage,
  134. ],
  135. ) -> Callable[
  136. [ModelRequest, Callable[[ModelRequest], ModelResponse]],
  137. ModelResponse,
  138. ]:
  139. """Compose two handlers where outer wraps inner."""
  140. def composed(
  141. request: ModelRequest,
  142. handler: Callable[[ModelRequest], ModelResponse],
  143. ) -> ModelResponse:
  144. # Create a wrapper that calls inner with the base handler and normalizes
  145. def inner_handler(req: ModelRequest) -> ModelResponse:
  146. inner_result = inner(req, handler)
  147. return _normalize_to_model_response(inner_result)
  148. # Call outer with the wrapped inner as its handler and normalize
  149. outer_result = outer(request, inner_handler)
  150. return _normalize_to_model_response(outer_result)
  151. return composed
  152. # Compose right-to-left: outer(inner(innermost(handler)))
  153. result = handlers[-1]
  154. for handler in reversed(handlers[:-1]):
  155. result = compose_two(handler, result)
  156. # Wrap to ensure final return type is exactly ModelResponse
  157. def final_normalized(
  158. request: ModelRequest,
  159. handler: Callable[[ModelRequest], ModelResponse],
  160. ) -> ModelResponse:
  161. # result here is typed as returning ModelResponse | AIMessage but compose_two normalizes
  162. final_result = result(request, handler)
  163. return _normalize_to_model_response(final_result)
  164. return final_normalized
  165. def _chain_async_model_call_handlers(
  166. handlers: Sequence[
  167. Callable[
  168. [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
  169. Awaitable[ModelResponse | AIMessage],
  170. ]
  171. ],
  172. ) -> (
  173. Callable[
  174. [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
  175. Awaitable[ModelResponse],
  176. ]
  177. | None
  178. ):
  179. """Compose multiple async `wrap_model_call` handlers into single middleware stack.
  180. Args:
  181. handlers: List of async handlers. First handler wraps all others.
  182. Returns:
  183. Composed async handler, or `None` if handlers empty.
  184. """
  185. if not handlers:
  186. return None
  187. if len(handlers) == 1:
  188. # Single handler - wrap to normalize output
  189. single_handler = handlers[0]
  190. async def normalized_single(
  191. request: ModelRequest,
  192. handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
  193. ) -> ModelResponse:
  194. result = await single_handler(request, handler)
  195. return _normalize_to_model_response(result)
  196. return normalized_single
  197. def compose_two(
  198. outer: Callable[
  199. [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
  200. Awaitable[ModelResponse | AIMessage],
  201. ],
  202. inner: Callable[
  203. [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
  204. Awaitable[ModelResponse | AIMessage],
  205. ],
  206. ) -> Callable[
  207. [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]],
  208. Awaitable[ModelResponse],
  209. ]:
  210. """Compose two async handlers where outer wraps inner."""
  211. async def composed(
  212. request: ModelRequest,
  213. handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
  214. ) -> ModelResponse:
  215. # Create a wrapper that calls inner with the base handler and normalizes
  216. async def inner_handler(req: ModelRequest) -> ModelResponse:
  217. inner_result = await inner(req, handler)
  218. return _normalize_to_model_response(inner_result)
  219. # Call outer with the wrapped inner as its handler and normalize
  220. outer_result = await outer(request, inner_handler)
  221. return _normalize_to_model_response(outer_result)
  222. return composed
  223. # Compose right-to-left: outer(inner(innermost(handler)))
  224. result = handlers[-1]
  225. for handler in reversed(handlers[:-1]):
  226. result = compose_two(handler, result)
  227. # Wrap to ensure final return type is exactly ModelResponse
  228. async def final_normalized(
  229. request: ModelRequest,
  230. handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
  231. ) -> ModelResponse:
  232. # result here is typed as returning ModelResponse | AIMessage but compose_two normalizes
  233. final_result = await result(request, handler)
  234. return _normalize_to_model_response(final_result)
  235. return final_normalized
  236. def _resolve_schema(schemas: set[type], schema_name: str, omit_flag: str | None = None) -> type:
  237. """Resolve schema by merging schemas and optionally respecting `OmitFromSchema` annotations.
  238. Args:
  239. schemas: List of schema types to merge
  240. schema_name: Name for the generated `TypedDict`
  241. omit_flag: If specified, omit fields with this flag set (`'input'` or
  242. `'output'`)
  243. """
  244. all_annotations = {}
  245. for schema in schemas:
  246. hints = get_type_hints(schema, include_extras=True)
  247. for field_name, field_type in hints.items():
  248. should_omit = False
  249. if omit_flag:
  250. # Check for omission in the annotation metadata
  251. metadata = _extract_metadata(field_type)
  252. for meta in metadata:
  253. if isinstance(meta, OmitFromSchema) and getattr(meta, omit_flag) is True:
  254. should_omit = True
  255. break
  256. if not should_omit:
  257. all_annotations[field_name] = field_type
  258. return TypedDict(schema_name, all_annotations) # type: ignore[operator]
  259. def _extract_metadata(type_: type) -> list:
  260. """Extract metadata from a field type, handling Required/NotRequired and Annotated wrappers."""
  261. # Handle Required[Annotated[...]] or NotRequired[Annotated[...]]
  262. if get_origin(type_) in (Required, NotRequired):
  263. inner_type = get_args(type_)[0]
  264. if get_origin(inner_type) is Annotated:
  265. return list(get_args(inner_type)[1:])
  266. # Handle direct Annotated[...]
  267. elif get_origin(type_) is Annotated:
  268. return list(get_args(type_)[1:])
  269. return []
  270. def _get_can_jump_to(middleware: AgentMiddleware[Any, Any], hook_name: str) -> list[JumpTo]:
  271. """Get the `can_jump_to` list from either sync or async hook methods.
  272. Args:
  273. middleware: The middleware instance to inspect.
  274. hook_name: The name of the hook (`'before_model'` or `'after_model'`).
  275. Returns:
  276. List of jump destinations, or empty list if not configured.
  277. """
  278. # Get the base class method for comparison
  279. base_sync_method = getattr(AgentMiddleware, hook_name, None)
  280. base_async_method = getattr(AgentMiddleware, f"a{hook_name}", None)
  281. # Try sync method first - only if it's overridden from base class
  282. sync_method = getattr(middleware.__class__, hook_name, None)
  283. if (
  284. sync_method
  285. and sync_method is not base_sync_method
  286. and hasattr(sync_method, "__can_jump_to__")
  287. ):
  288. return sync_method.__can_jump_to__
  289. # Try async method - only if it's overridden from base class
  290. async_method = getattr(middleware.__class__, f"a{hook_name}", None)
  291. if (
  292. async_method
  293. and async_method is not base_async_method
  294. and hasattr(async_method, "__can_jump_to__")
  295. ):
  296. return async_method.__can_jump_to__
  297. return []
  298. def _supports_provider_strategy(model: str | BaseChatModel, tools: list | None = None) -> bool:
  299. """Check if a model supports provider-specific structured output.
  300. Args:
  301. model: Model name string or `BaseChatModel` instance.
  302. tools: Optional list of tools provided to the agent. Needed because some models
  303. don't support structured output together with tool calling.
  304. Returns:
  305. `True` if the model supports provider-specific structured output, `False` otherwise.
  306. """
  307. model_name: str | None = None
  308. if isinstance(model, str):
  309. model_name = model
  310. elif isinstance(model, BaseChatModel):
  311. model_name = (
  312. getattr(model, "model_name", None)
  313. or getattr(model, "model", None)
  314. or getattr(model, "model_id", "")
  315. )
  316. model_profile = model.profile
  317. if (
  318. model_profile is not None
  319. and model_profile.get("structured_output")
  320. # We make an exception for Gemini models, which currently do not support
  321. # simultaneous tool use with structured output
  322. and not (tools and isinstance(model_name, str) and "gemini" in model_name.lower())
  323. ):
  324. return True
  325. return (
  326. any(part in model_name.lower() for part in FALLBACK_MODELS_WITH_STRUCTURED_OUTPUT)
  327. if model_name
  328. else False
  329. )
  330. def _handle_structured_output_error(
  331. exception: Exception,
  332. response_format: ResponseFormat,
  333. ) -> tuple[bool, str]:
  334. """Handle structured output error. Returns `(should_retry, retry_tool_message)`."""
  335. if not isinstance(response_format, ToolStrategy):
  336. return False, ""
  337. handle_errors = response_format.handle_errors
  338. if handle_errors is False:
  339. return False, ""
  340. if handle_errors is True:
  341. return True, STRUCTURED_OUTPUT_ERROR_TEMPLATE.format(error=str(exception))
  342. if isinstance(handle_errors, str):
  343. return True, handle_errors
  344. if isinstance(handle_errors, type) and issubclass(handle_errors, Exception):
  345. if isinstance(exception, handle_errors):
  346. return True, STRUCTURED_OUTPUT_ERROR_TEMPLATE.format(error=str(exception))
  347. return False, ""
  348. if isinstance(handle_errors, tuple):
  349. if any(isinstance(exception, exc_type) for exc_type in handle_errors):
  350. return True, STRUCTURED_OUTPUT_ERROR_TEMPLATE.format(error=str(exception))
  351. return False, ""
  352. if callable(handle_errors):
  353. # type narrowing not working appropriately w/ callable check, can fix later
  354. return True, handle_errors(exception) # type: ignore[return-value,call-arg]
  355. return False, ""
  356. def _chain_tool_call_wrappers(
  357. wrappers: Sequence[ToolCallWrapper],
  358. ) -> ToolCallWrapper | None:
  359. """Compose wrappers into middleware stack (first = outermost).
  360. Args:
  361. wrappers: Wrappers in middleware order.
  362. Returns:
  363. Composed wrapper, or `None` if empty.
  364. Example:
  365. wrapper = _chain_tool_call_wrappers([auth, cache, retry])
  366. # Request flows: auth -> cache -> retry -> tool
  367. # Response flows: tool -> retry -> cache -> auth
  368. """
  369. if not wrappers:
  370. return None
  371. if len(wrappers) == 1:
  372. return wrappers[0]
  373. def compose_two(outer: ToolCallWrapper, inner: ToolCallWrapper) -> ToolCallWrapper:
  374. """Compose two wrappers where outer wraps inner."""
  375. def composed(
  376. request: ToolCallRequest,
  377. execute: Callable[[ToolCallRequest], ToolMessage | Command],
  378. ) -> ToolMessage | Command:
  379. # Create a callable that invokes inner with the original execute
  380. def call_inner(req: ToolCallRequest) -> ToolMessage | Command:
  381. return inner(req, execute)
  382. # Outer can call call_inner multiple times
  383. return outer(request, call_inner)
  384. return composed
  385. # Chain all wrappers: first -> second -> ... -> last
  386. result = wrappers[-1]
  387. for wrapper in reversed(wrappers[:-1]):
  388. result = compose_two(wrapper, result)
  389. return result
  390. def _chain_async_tool_call_wrappers(
  391. wrappers: Sequence[
  392. Callable[
  393. [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
  394. Awaitable[ToolMessage | Command],
  395. ]
  396. ],
  397. ) -> (
  398. Callable[
  399. [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
  400. Awaitable[ToolMessage | Command],
  401. ]
  402. | None
  403. ):
  404. """Compose async wrappers into middleware stack (first = outermost).
  405. Args:
  406. wrappers: Async wrappers in middleware order.
  407. Returns:
  408. Composed async wrapper, or `None` if empty.
  409. """
  410. if not wrappers:
  411. return None
  412. if len(wrappers) == 1:
  413. return wrappers[0]
  414. def compose_two(
  415. outer: Callable[
  416. [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
  417. Awaitable[ToolMessage | Command],
  418. ],
  419. inner: Callable[
  420. [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
  421. Awaitable[ToolMessage | Command],
  422. ],
  423. ) -> Callable[
  424. [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
  425. Awaitable[ToolMessage | Command],
  426. ]:
  427. """Compose two async wrappers where outer wraps inner."""
  428. async def composed(
  429. request: ToolCallRequest,
  430. execute: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
  431. ) -> ToolMessage | Command:
  432. # Create an async callable that invokes inner with the original execute
  433. async def call_inner(req: ToolCallRequest) -> ToolMessage | Command:
  434. return await inner(req, execute)
  435. # Outer can call call_inner multiple times
  436. return await outer(request, call_inner)
  437. return composed
  438. # Chain all wrappers: first -> second -> ... -> last
  439. result = wrappers[-1]
  440. for wrapper in reversed(wrappers[:-1]):
  441. result = compose_two(wrapper, result)
  442. return result
  443. def create_agent( # noqa: PLR0915
  444. model: str | BaseChatModel,
  445. tools: Sequence[BaseTool | Callable | dict[str, Any]] | None = None,
  446. *,
  447. system_prompt: str | SystemMessage | None = None,
  448. middleware: Sequence[AgentMiddleware[StateT_co, ContextT]] = (),
  449. response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None,
  450. state_schema: type[AgentState[ResponseT]] | None = None,
  451. context_schema: type[ContextT] | None = None,
  452. checkpointer: Checkpointer | None = None,
  453. store: BaseStore | None = None,
  454. interrupt_before: list[str] | None = None,
  455. interrupt_after: list[str] | None = None,
  456. debug: bool = False,
  457. name: str | None = None,
  458. cache: BaseCache | None = None,
  459. ) -> CompiledStateGraph[
  460. AgentState[ResponseT], ContextT, _InputAgentState, _OutputAgentState[ResponseT]
  461. ]:
  462. """Creates an agent graph that calls tools in a loop until a stopping condition is met.
  463. For more details on using `create_agent`,
  464. visit the [Agents](https://docs.langchain.com/oss/python/langchain/agents) docs.
  465. Args:
  466. model: The language model for the agent.
  467. Can be a string identifier (e.g., `"openai:gpt-4"`) or a direct chat model
  468. instance (e.g., [`ChatOpenAI`][langchain_openai.ChatOpenAI] or other another
  469. [LangChain chat model](https://docs.langchain.com/oss/python/integrations/chat)).
  470. For a full list of supported model strings, see
  471. [`init_chat_model`][langchain.chat_models.init_chat_model(model_provider)].
  472. !!! tip ""
  473. See the [Models](https://docs.langchain.com/oss/python/langchain/models)
  474. docs for more information.
  475. tools: A list of tools, `dict`, or `Callable`.
  476. If `None` or an empty list, the agent will consist of a model node without a
  477. tool calling loop.
  478. !!! tip ""
  479. See the [Tools](https://docs.langchain.com/oss/python/langchain/tools)
  480. docs for more information.
  481. system_prompt: An optional system prompt for the LLM.
  482. Can be a `str` (which will be converted to a `SystemMessage`) or a
  483. `SystemMessage` instance directly. The system message is added to the
  484. beginning of the message list when calling the model.
  485. middleware: A sequence of middleware instances to apply to the agent.
  486. Middleware can intercept and modify agent behavior at various stages.
  487. !!! tip ""
  488. See the [Middleware](https://docs.langchain.com/oss/python/langchain/middleware)
  489. docs for more information.
  490. response_format: An optional configuration for structured responses.
  491. Can be a `ToolStrategy`, `ProviderStrategy`, or a Pydantic model class.
  492. If provided, the agent will handle structured output during the
  493. conversation flow.
  494. Raw schemas will be wrapped in an appropriate strategy based on model
  495. capabilities.
  496. !!! tip ""
  497. See the [Structured output](https://docs.langchain.com/oss/python/langchain/structured-output)
  498. docs for more information.
  499. state_schema: An optional `TypedDict` schema that extends `AgentState`.
  500. When provided, this schema is used instead of `AgentState` as the base
  501. schema for merging with middleware state schemas. This allows users to
  502. add custom state fields without needing to create custom middleware.
  503. Generally, it's recommended to use `state_schema` extensions via middleware
  504. to keep relevant extensions scoped to corresponding hooks / tools.
  505. context_schema: An optional schema for runtime context.
  506. checkpointer: An optional checkpoint saver object.
  507. Used for persisting the state of the graph (e.g., as chat memory) for a
  508. single thread (e.g., a single conversation).
  509. store: An optional store object.
  510. Used for persisting data across multiple threads (e.g., multiple
  511. conversations / users).
  512. interrupt_before: An optional list of node names to interrupt before.
  513. Useful if you want to add a user confirmation or other interrupt
  514. before taking an action.
  515. interrupt_after: An optional list of node names to interrupt after.
  516. Useful if you want to return directly or run additional processing
  517. on an output.
  518. debug: Whether to enable verbose logging for graph execution.
  519. When enabled, prints detailed information about each node execution, state
  520. updates, and transitions during agent runtime. Useful for debugging
  521. middleware behavior and understanding agent execution flow.
  522. name: An optional name for the `CompiledStateGraph`.
  523. This name will be automatically used when adding the agent graph to
  524. another graph as a subgraph node - particularly useful for building
  525. multi-agent systems.
  526. cache: An optional `BaseCache` instance to enable caching of graph execution.
  527. Returns:
  528. A compiled `StateGraph` that can be used for chat interactions.
  529. The agent node calls the language model with the messages list (after applying
  530. the system prompt). If the resulting [`AIMessage`][langchain.messages.AIMessage]
  531. contains `tool_calls`, the graph will then call the tools. The tools node executes
  532. the tools and adds the responses to the messages list as
  533. [`ToolMessage`][langchain.messages.ToolMessage] objects. The agent node then calls
  534. the language model again. The process repeats until no more `tool_calls` are present
  535. in the response. The agent then returns the full list of messages.
  536. Example:
  537. ```python
  538. from langchain.agents import create_agent
  539. def check_weather(location: str) -> str:
  540. '''Return the weather forecast for the specified location.'''
  541. return f"It's always sunny in {location}"
  542. graph = create_agent(
  543. model="anthropic:claude-sonnet-4-5-20250929",
  544. tools=[check_weather],
  545. system_prompt="You are a helpful assistant",
  546. )
  547. inputs = {"messages": [{"role": "user", "content": "what is the weather in sf"}]}
  548. for chunk in graph.stream(inputs, stream_mode="updates"):
  549. print(chunk)
  550. ```
  551. """
  552. # init chat model
  553. if isinstance(model, str):
  554. model = init_chat_model(model)
  555. # Convert system_prompt to SystemMessage if needed
  556. system_message: SystemMessage | None = None
# Normalize the system prompt into a SystemMessage (accepts str or SystemMessage).
# NOTE(review): assumes `system_message` is initialized earlier in the enclosing
# function for the `system_prompt is None` case — confirm against full source.
if system_prompt is not None:
    if isinstance(system_prompt, SystemMessage):
        system_message = system_prompt
    else:
        system_message = SystemMessage(content=system_prompt)
# Handle tools being None or empty
if tools is None:
    tools = []
# Convert response format and setup structured output tools
# Raw schemas are wrapped in AutoStrategy to preserve auto-detection intent.
# AutoStrategy is converted to ToolStrategy upfront to calculate tools during agent creation,
# but may be replaced with ProviderStrategy later based on model capabilities.
initial_response_format: ToolStrategy | ProviderStrategy | AutoStrategy | None
if response_format is None:
    initial_response_format = None
elif isinstance(response_format, (ToolStrategy, ProviderStrategy)):
    # Preserve explicitly requested strategies
    initial_response_format = response_format
elif isinstance(response_format, AutoStrategy):
    # AutoStrategy provided - preserve it for later auto-detection
    initial_response_format = response_format
else:
    # Raw schema - wrap in AutoStrategy to enable auto-detection
    initial_response_format = AutoStrategy(schema=response_format)
# For AutoStrategy, convert to ToolStrategy to setup tools upfront
# (may be replaced with ProviderStrategy later based on model)
tool_strategy_for_setup: ToolStrategy | None = None
if isinstance(initial_response_format, AutoStrategy):
    tool_strategy_for_setup = ToolStrategy(schema=initial_response_format.schema)
elif isinstance(initial_response_format, ToolStrategy):
    tool_strategy_for_setup = initial_response_format
# Map of structured-output tool name -> binding used later to parse tool args
# into the structured response.
structured_output_tools: dict[str, OutputToolBinding] = {}
if tool_strategy_for_setup:
    for response_schema in tool_strategy_for_setup.schema_specs:
        structured_tool_info = OutputToolBinding.from_schema_spec(response_schema)
        structured_output_tools[structured_tool_info.tool.name] = structured_tool_info
# Tools contributed by middleware instances via their `tools` attribute.
middleware_tools = [t for m in middleware for t in getattr(m, "tools", [])]
# Collect middleware with wrap_tool_call or awrap_tool_call hooks
# Include middleware with either implementation to ensure NotImplementedError is raised
# when middleware doesn't support the execution path
middleware_w_wrap_tool_call = [
    m
    for m in middleware
    if m.__class__.wrap_tool_call is not AgentMiddleware.wrap_tool_call
    or m.__class__.awrap_tool_call is not AgentMiddleware.awrap_tool_call
]
# Chain all wrap_tool_call handlers into a single composed handler
wrap_tool_call_wrapper = None
if middleware_w_wrap_tool_call:
    wrappers = [m.wrap_tool_call for m in middleware_w_wrap_tool_call]
    wrap_tool_call_wrapper = _chain_tool_call_wrappers(wrappers)
# Collect middleware with awrap_tool_call or wrap_tool_call hooks
# Include middleware with either implementation to ensure NotImplementedError is raised
# when middleware doesn't support the execution path
middleware_w_awrap_tool_call = [
    m
    for m in middleware
    if m.__class__.awrap_tool_call is not AgentMiddleware.awrap_tool_call
    or m.__class__.wrap_tool_call is not AgentMiddleware.wrap_tool_call
]
# Chain all awrap_tool_call handlers into a single composed async handler
awrap_tool_call_wrapper = None
if middleware_w_awrap_tool_call:
    async_wrappers = [m.awrap_tool_call for m in middleware_w_awrap_tool_call]
    awrap_tool_call_wrapper = _chain_async_tool_call_wrappers(async_wrappers)
# Setup tools
tool_node: ToolNode | None = None
# Extract built-in provider tools (dict format) and regular tools (BaseTool/callables)
built_in_tools = [t for t in tools if isinstance(t, dict)]
regular_tools = [t for t in tools if not isinstance(t, dict)]
# Tools that require client-side execution (must be in ToolNode)
available_tools = middleware_tools + regular_tools
# Only create ToolNode if we have client-side tools
tool_node = (
    ToolNode(
        tools=available_tools,
        wrap_tool_call=wrap_tool_call_wrapper,
        awrap_tool_call=awrap_tool_call_wrapper,
    )
    if available_tools
    else None
)
# Default tools for ModelRequest initialization
# Use converted BaseTool instances from ToolNode (not raw callables)
# Include built-ins and converted tools (can be changed dynamically by middleware)
# Structured tools are NOT included - they're added dynamically based on response_format
if tool_node:
    default_tools = list(tool_node.tools_by_name.values()) + built_in_tools
else:
    default_tools = list(built_in_tools)
# validate middleware
# NOTE(review): `assert` is stripped under `python -O`; duplicates would then
# pass silently. Kept as-is (matches existing convention, noqa S101).
assert len({m.name for m in middleware}) == len(middleware), (  # noqa: S101
    "Please remove duplicate middleware instances."
)
# Partition middleware by which lifecycle hooks they override (sync or async
# variant counts — the default base-class implementation is excluded via `is not`).
middleware_w_before_agent = [
    m
    for m in middleware
    if m.__class__.before_agent is not AgentMiddleware.before_agent
    or m.__class__.abefore_agent is not AgentMiddleware.abefore_agent
]
middleware_w_before_model = [
    m
    for m in middleware
    if m.__class__.before_model is not AgentMiddleware.before_model
    or m.__class__.abefore_model is not AgentMiddleware.abefore_model
]
middleware_w_after_model = [
    m
    for m in middleware
    if m.__class__.after_model is not AgentMiddleware.after_model
    or m.__class__.aafter_model is not AgentMiddleware.aafter_model
]
middleware_w_after_agent = [
    m
    for m in middleware
    if m.__class__.after_agent is not AgentMiddleware.after_agent
    or m.__class__.aafter_agent is not AgentMiddleware.aafter_agent
]
# Collect middleware with wrap_model_call or awrap_model_call hooks
# Include middleware with either implementation to ensure NotImplementedError is raised
# when middleware doesn't support the execution path
middleware_w_wrap_model_call = [
    m
    for m in middleware
    if m.__class__.wrap_model_call is not AgentMiddleware.wrap_model_call
    or m.__class__.awrap_model_call is not AgentMiddleware.awrap_model_call
]
# Collect middleware with awrap_model_call or wrap_model_call hooks
# Include middleware with either implementation to ensure NotImplementedError is raised
# when middleware doesn't support the execution path
middleware_w_awrap_model_call = [
    m
    for m in middleware
    if m.__class__.awrap_model_call is not AgentMiddleware.awrap_model_call
    or m.__class__.wrap_model_call is not AgentMiddleware.wrap_model_call
]
# Compose wrap_model_call handlers into a single middleware stack (sync)
wrap_model_call_handler = None
if middleware_w_wrap_model_call:
    sync_handlers = [m.wrap_model_call for m in middleware_w_wrap_model_call]
    wrap_model_call_handler = _chain_model_call_handlers(sync_handlers)
# Compose awrap_model_call handlers into a single middleware stack (async)
awrap_model_call_handler = None
if middleware_w_awrap_model_call:
    async_handlers = [m.awrap_model_call for m in middleware_w_awrap_model_call]
    awrap_model_call_handler = _chain_async_model_call_handlers(async_handlers)
# Merge every middleware's state schema with the base agent state before
# resolving the graph's state/input/output schemas.
state_schemas: set[type] = {m.state_schema for m in middleware}
# Use provided state_schema if available, otherwise use base AgentState
base_state = state_schema if state_schema is not None else AgentState
state_schemas.add(base_state)
resolved_state_schema = _resolve_schema(state_schemas, "StateSchema", None)
input_schema = _resolve_schema(state_schemas, "InputSchema", "input")
output_schema = _resolve_schema(state_schemas, "OutputSchema", "output")
# create graph, add nodes
graph: StateGraph[
    AgentState[ResponseT], ContextT, _InputAgentState, _OutputAgentState[ResponseT]
] = StateGraph(
    state_schema=resolved_state_schema,
    input_schema=input_schema,
    output_schema=output_schema,
    context_schema=context_schema,
)
def _handle_model_output(
    output: AIMessage, effective_response_format: ResponseFormat | None
) -> dict[str, Any]:
    """Handle model output including structured responses.

    Args:
        output: The AI message output from the model.
        effective_response_format: The actual strategy used
            (may differ from initial if auto-detected).

    Returns:
        A state-update dict always containing `"messages"`, plus
        `"structured_response"` when structured output parsing succeeds.

    Raises:
        StructuredOutputValidationError: When parsing fails and the strategy
            does not allow a retry.
        MultipleStructuredOutputsError: When the model emitted more than one
            structured-output tool call and retry is not allowed.
    """
    # Handle structured output with provider strategy
    if isinstance(effective_response_format, ProviderStrategy):
        # Provider-native structured output only applies when the model is NOT
        # calling tools in this turn.
        if not output.tool_calls:
            provider_strategy_binding = ProviderStrategyBinding.from_schema_spec(
                effective_response_format.schema_spec
            )
            try:
                structured_response = provider_strategy_binding.parse(output)
            except Exception as exc:  # noqa: BLE001
                # Fall back to the schema's class name for the error, or a
                # generic label when the schema has no __name__.
                schema_name = getattr(
                    effective_response_format.schema_spec.schema, "__name__", "response_format"
                )
                validation_error = StructuredOutputValidationError(schema_name, exc, output)
                raise validation_error
            else:
                return {"messages": [output], "structured_response": structured_response}
        return {"messages": [output]}
    # Handle structured output with tool strategy
    if (
        isinstance(effective_response_format, ToolStrategy)
        and isinstance(output, AIMessage)
        and output.tool_calls
    ):
        # Only tool calls registered as structured-output tools are handled here;
        # regular tool calls flow through the tools node instead.
        structured_tool_calls = [
            tc for tc in output.tool_calls if tc["name"] in structured_output_tools
        ]
        if structured_tool_calls:
            exception: StructuredOutputError | None = None
            if len(structured_tool_calls) > 1:
                # Handle multiple structured outputs error
                tool_names = [tc["name"] for tc in structured_tool_calls]
                exception = MultipleStructuredOutputsError(tool_names, output)
                should_retry, error_message = _handle_structured_output_error(
                    exception, effective_response_format
                )
                if not should_retry:
                    raise exception
                # Add error messages and retry: one ToolMessage per offending
                # call so the model can correct itself next turn.
                tool_messages = [
                    ToolMessage(
                        content=error_message,
                        tool_call_id=tc["id"],
                        name=tc["name"],
                    )
                    for tc in structured_tool_calls
                ]
                return {"messages": [output, *tool_messages]}
            # Handle single structured output
            tool_call = structured_tool_calls[0]
            try:
                structured_tool_binding = structured_output_tools[tool_call["name"]]
                structured_response = structured_tool_binding.parse(tool_call["args"])
                # Use the strategy's custom confirmation text when provided.
                tool_message_content = (
                    effective_response_format.tool_message_content
                    if effective_response_format.tool_message_content
                    else f"Returning structured response: {structured_response}"
                )
                return {
                    "messages": [
                        output,
                        ToolMessage(
                            content=tool_message_content,
                            tool_call_id=tool_call["id"],
                            name=tool_call["name"],
                        ),
                    ],
                    "structured_response": structured_response,
                }
            except Exception as exc:  # noqa: BLE001
                exception = StructuredOutputValidationError(tool_call["name"], exc, output)
                should_retry, error_message = _handle_structured_output_error(
                    exception, effective_response_format
                )
                if not should_retry:
                    raise exception
                # Retry path: surface the parse error to the model via ToolMessage.
                return {
                    "messages": [
                        output,
                        ToolMessage(
                            content=error_message,
                            tool_call_id=tool_call["id"],
                            name=tool_call["name"],
                        ),
                    ],
                }
    # No structured output handling applies - pass the message through.
    return {"messages": [output]}
def _get_bound_model(request: ModelRequest) -> tuple[Runnable, ResponseFormat | None]:
    """Get the model with appropriate tool bindings.

    Performs auto-detection of strategy if needed based on model capabilities.

    Args:
        request: The model request containing model, tools, and response format.

    Returns:
        Tuple of `(bound_model, effective_response_format)` where
        `effective_response_format` is the actual strategy used (may differ from
        initial if auto-detected).

    Raises:
        ValueError: If middleware requested a client-side tool that was never
            registered, or a ToolStrategy names a tool not declared upfront.
    """
    # Validate ONLY client-side tools that need to exist in tool_node
    # Build map of available client-side tools from the ToolNode
    # (which has already converted callables)
    available_tools_by_name = {}
    if tool_node:
        available_tools_by_name = tool_node.tools_by_name.copy()
    # Check if any requested tools are unknown CLIENT-SIDE tools
    unknown_tool_names = []
    for t in request.tools:
        # Only validate BaseTool instances (skip built-in dict tools)
        if isinstance(t, dict):
            continue
        if isinstance(t, BaseTool) and t.name not in available_tools_by_name:
            unknown_tool_names.append(t.name)
    if unknown_tool_names:
        available_tool_names = sorted(available_tools_by_name.keys())
        msg = (
            f"Middleware returned unknown tool names: {unknown_tool_names}\n\n"
            f"Available client-side tools: {available_tool_names}\n\n"
            "To fix this issue:\n"
            "1. Ensure the tools are passed to create_agent() via "
            "the 'tools' parameter\n"
            "2. If using custom middleware with tools, ensure "
            "they're registered via middleware.tools attribute\n"
            "3. Verify that tool names in ModelRequest.tools match "
            "the actual tool.name values\n"
            "Note: Built-in provider tools (dict format) can be added dynamically."
        )
        raise ValueError(msg)
    # Determine effective response format (auto-detect if needed)
    effective_response_format: ResponseFormat | None
    if isinstance(request.response_format, AutoStrategy):
        # User provided raw schema via AutoStrategy - auto-detect best strategy based on model
        if _supports_provider_strategy(request.model, tools=request.tools):
            # Model supports provider strategy - use it
            effective_response_format = ProviderStrategy(schema=request.response_format.schema)
        else:
            # Model doesn't support provider strategy - use ToolStrategy
            effective_response_format = ToolStrategy(schema=request.response_format.schema)
    else:
        # User explicitly specified a strategy - preserve it
        effective_response_format = request.response_format
    # Build final tools list including structured output tools
    # request.tools now only contains BaseTool instances (converted from callables)
    # and dicts (built-ins)
    final_tools = list(request.tools)
    if isinstance(effective_response_format, ToolStrategy):
        # Add structured output tools to final tools list
        structured_tools = [info.tool for info in structured_output_tools.values()]
        final_tools.extend(structured_tools)
    # Bind model based on effective response format
    if isinstance(effective_response_format, ProviderStrategy):
        # (Backward compatibility) Use OpenAI format structured output
        kwargs = effective_response_format.to_model_kwargs()
        return (
            request.model.bind_tools(
                final_tools, strict=True, **kwargs, **request.model_settings
            ),
            effective_response_format,
        )
    if isinstance(effective_response_format, ToolStrategy):
        # Current implementation requires that tools used for structured output
        # have to be declared upfront when creating the agent as part of the
        # response format. Middleware is allowed to change the response format
        # to a subset of the original structured tools when using ToolStrategy,
        # but not to add new structured tools that weren't declared upfront.
        # Compute output binding
        for tc in effective_response_format.schema_specs:
            if tc.name not in structured_output_tools:
                msg = (
                    f"ToolStrategy specifies tool '{tc.name}' "
                    "which wasn't declared in the original "
                    "response format when creating the agent."
                )
                raise ValueError(msg)
        # Force tool use if we have structured output tools
        tool_choice = "any" if structured_output_tools else request.tool_choice
        return (
            request.model.bind_tools(
                final_tools, tool_choice=tool_choice, **request.model_settings
            ),
            effective_response_format,
        )
    # No structured output - standard model binding
    if final_tools:
        return (
            request.model.bind_tools(
                final_tools, tool_choice=request.tool_choice, **request.model_settings
            ),
            None,
        )
    return request.model.bind(**request.model_settings), None
  916. def _execute_model_sync(request: ModelRequest) -> ModelResponse:
  917. """Execute model and return response.
  918. This is the core model execution logic wrapped by `wrap_model_call` handlers.
  919. Raises any exceptions that occur during model invocation.
  920. """
  921. # Get the bound model (with auto-detection if needed)
  922. model_, effective_response_format = _get_bound_model(request)
  923. messages = request.messages
  924. if request.system_message:
  925. messages = [request.system_message, *messages]
  926. output = model_.invoke(messages)
  927. if name:
  928. output.name = name
  929. # Handle model output to get messages and structured_response
  930. handled_output = _handle_model_output(output, effective_response_format)
  931. messages_list = handled_output["messages"]
  932. structured_response = handled_output.get("structured_response")
  933. return ModelResponse(
  934. result=messages_list,
  935. structured_response=structured_response,
  936. )
  937. def model_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
  938. """Sync model request handler with sequential middleware processing."""
  939. request = ModelRequest(
  940. model=model,
  941. tools=default_tools,
  942. system_message=system_message,
  943. response_format=initial_response_format,
  944. messages=state["messages"],
  945. tool_choice=None,
  946. state=state,
  947. runtime=runtime,
  948. )
  949. if wrap_model_call_handler is None:
  950. # No handlers - execute directly
  951. response = _execute_model_sync(request)
  952. else:
  953. # Call composed handler with base handler
  954. response = wrap_model_call_handler(request, _execute_model_sync)
  955. # Extract state updates from ModelResponse
  956. state_updates = {"messages": response.result}
  957. if response.structured_response is not None:
  958. state_updates["structured_response"] = response.structured_response
  959. return state_updates
  960. async def _execute_model_async(request: ModelRequest) -> ModelResponse:
  961. """Execute model asynchronously and return response.
  962. This is the core async model execution logic wrapped by `wrap_model_call`
  963. handlers.
  964. Raises any exceptions that occur during model invocation.
  965. """
  966. # Get the bound model (with auto-detection if needed)
  967. model_, effective_response_format = _get_bound_model(request)
  968. messages = request.messages
  969. if request.system_message:
  970. messages = [request.system_message, *messages]
  971. output = await model_.ainvoke(messages)
  972. if name:
  973. output.name = name
  974. # Handle model output to get messages and structured_response
  975. handled_output = _handle_model_output(output, effective_response_format)
  976. messages_list = handled_output["messages"]
  977. structured_response = handled_output.get("structured_response")
  978. return ModelResponse(
  979. result=messages_list,
  980. structured_response=structured_response,
  981. )
  982. async def amodel_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
  983. """Async model request handler with sequential middleware processing."""
  984. request = ModelRequest(
  985. model=model,
  986. tools=default_tools,
  987. system_message=system_message,
  988. response_format=initial_response_format,
  989. messages=state["messages"],
  990. tool_choice=None,
  991. state=state,
  992. runtime=runtime,
  993. )
  994. if awrap_model_call_handler is None:
  995. # No async handlers - execute directly
  996. response = await _execute_model_async(request)
  997. else:
  998. # Call composed async handler with base handler
  999. response = await awrap_model_call_handler(request, _execute_model_async)
  1000. # Extract state updates from ModelResponse
  1001. state_updates = {"messages": response.result}
  1002. if response.structured_response is not None:
  1003. state_updates["structured_response"] = response.structured_response
  1004. return state_updates
# Use sync or async based on model capabilities
graph.add_node("model", RunnableCallable(model_node, amodel_node, trace=False))
# Only add tools node if we have tools
if tool_node is not None:
    graph.add_node("tools", tool_node)
# Add middleware nodes: one graph node per overridden lifecycle hook, named
# "<middleware-name>.<hook>".
for m in middleware:
    if (
        m.__class__.before_agent is not AgentMiddleware.before_agent
        or m.__class__.abefore_agent is not AgentMiddleware.abefore_agent
    ):
        # Use RunnableCallable to support both sync and async
        # Pass None for sync if not overridden to avoid signature conflicts
        sync_before_agent = (
            m.before_agent
            if m.__class__.before_agent is not AgentMiddleware.before_agent
            else None
        )
        async_before_agent = (
            m.abefore_agent
            if m.__class__.abefore_agent is not AgentMiddleware.abefore_agent
            else None
        )
        before_agent_node = RunnableCallable(sync_before_agent, async_before_agent, trace=False)
        graph.add_node(
            f"{m.name}.before_agent", before_agent_node, input_schema=resolved_state_schema
        )
    if (
        m.__class__.before_model is not AgentMiddleware.before_model
        or m.__class__.abefore_model is not AgentMiddleware.abefore_model
    ):
        # Use RunnableCallable to support both sync and async
        # Pass None for sync if not overridden to avoid signature conflicts
        sync_before = (
            m.before_model
            if m.__class__.before_model is not AgentMiddleware.before_model
            else None
        )
        async_before = (
            m.abefore_model
            if m.__class__.abefore_model is not AgentMiddleware.abefore_model
            else None
        )
        before_node = RunnableCallable(sync_before, async_before, trace=False)
        graph.add_node(
            f"{m.name}.before_model", before_node, input_schema=resolved_state_schema
        )
    if (
        m.__class__.after_model is not AgentMiddleware.after_model
        or m.__class__.aafter_model is not AgentMiddleware.aafter_model
    ):
        # Use RunnableCallable to support both sync and async
        # Pass None for sync if not overridden to avoid signature conflicts
        sync_after = (
            m.after_model
            if m.__class__.after_model is not AgentMiddleware.after_model
            else None
        )
        async_after = (
            m.aafter_model
            if m.__class__.aafter_model is not AgentMiddleware.aafter_model
            else None
        )
        after_node = RunnableCallable(sync_after, async_after, trace=False)
        graph.add_node(f"{m.name}.after_model", after_node, input_schema=resolved_state_schema)
    if (
        m.__class__.after_agent is not AgentMiddleware.after_agent
        or m.__class__.aafter_agent is not AgentMiddleware.aafter_agent
    ):
        # Use RunnableCallable to support both sync and async
        # Pass None for sync if not overridden to avoid signature conflicts
        sync_after_agent = (
            m.after_agent
            if m.__class__.after_agent is not AgentMiddleware.after_agent
            else None
        )
        async_after_agent = (
            m.aafter_agent
            if m.__class__.aafter_agent is not AgentMiddleware.aafter_agent
            else None
        )
        after_agent_node = RunnableCallable(sync_after_agent, async_after_agent, trace=False)
        graph.add_node(
            f"{m.name}.after_agent", after_agent_node, input_schema=resolved_state_schema
        )
# Determine the entry node (runs once at start): before_agent -> before_model -> model
if middleware_w_before_agent:
    entry_node = f"{middleware_w_before_agent[0].name}.before_agent"
elif middleware_w_before_model:
    entry_node = f"{middleware_w_before_model[0].name}.before_model"
else:
    entry_node = "model"
# Determine the loop entry node (beginning of agent loop, excludes before_agent)
# This is where tools will loop back to for the next iteration
if middleware_w_before_model:
    loop_entry_node = f"{middleware_w_before_model[0].name}.before_model"
else:
    loop_entry_node = "model"
# Determine the loop exit node (end of each iteration, can run multiple times)
# This is after_model or model, but NOT after_agent
if middleware_w_after_model:
    loop_exit_node = f"{middleware_w_after_model[0].name}.after_model"
else:
    loop_exit_node = "model"
# Determine the exit node (runs once at end): after_agent or END
if middleware_w_after_agent:
    exit_node = f"{middleware_w_after_agent[-1].name}.after_agent"
else:
    exit_node = END
graph.add_edge(START, entry_node)
# add conditional edges only if tools exist
if tool_node is not None:
    # Only include exit_node in destinations if any tool has return_direct=True
    # or if there are structured output tools
    tools_to_model_destinations = [loop_entry_node]
    if (
        any(tool.return_direct for tool in tool_node.tools_by_name.values())
        or structured_output_tools
    ):
        tools_to_model_destinations.append(exit_node)
    graph.add_conditional_edges(
        "tools",
        RunnableCallable(
            _make_tools_to_model_edge(
                tool_node=tool_node,
                model_destination=loop_entry_node,
                structured_output_tools=structured_output_tools,
                end_destination=exit_node,
            ),
            trace=False,
        ),
        tools_to_model_destinations,
    )
    # base destinations are tools and exit_node
    # we add the loop_entry node to edge destinations if:
    # - there is an after model hook(s) -- allows jump_to to model
    # potentially artificially injected tool messages, ex HITL
    # - there is a response format -- to allow for jumping to model to handle
    # regenerating structured output tool calls
    model_to_tools_destinations = ["tools", exit_node]
    if response_format or loop_exit_node != "model":
        model_to_tools_destinations.append(loop_entry_node)
    graph.add_conditional_edges(
        loop_exit_node,
        RunnableCallable(
            _make_model_to_tools_edge(
                model_destination=loop_entry_node,
                structured_output_tools=structured_output_tools,
                end_destination=exit_node,
            ),
            trace=False,
        ),
        model_to_tools_destinations,
    )
elif len(structured_output_tools) > 0:
    # No tool node, but structured output may need retries: loop model -> model.
    graph.add_conditional_edges(
        loop_exit_node,
        RunnableCallable(
            _make_model_to_model_edge(
                model_destination=loop_entry_node,
                end_destination=exit_node,
            ),
            trace=False,
        ),
        [loop_entry_node, exit_node],
    )
elif loop_exit_node == "model":
    # If no tools and no after_model, go directly to exit_node
    graph.add_edge(loop_exit_node, exit_node)
# No tools but we have after_model - connect after_model to exit_node
else:
    _add_middleware_edge(
        graph,
        name=f"{middleware_w_after_model[0].name}.after_model",
        default_destination=exit_node,
        model_destination=loop_entry_node,
        end_destination=exit_node,
        can_jump_to=_get_can_jump_to(middleware_w_after_model[0], "after_model"),
    )
# Add before_agent middleware edges
if middleware_w_before_agent:
    for m1, m2 in itertools.pairwise(middleware_w_before_agent):
        _add_middleware_edge(
            graph,
            name=f"{m1.name}.before_agent",
            default_destination=f"{m2.name}.before_agent",
            model_destination=loop_entry_node,
            end_destination=exit_node,
            can_jump_to=_get_can_jump_to(m1, "before_agent"),
        )
    # Connect last before_agent to loop_entry_node (before_model or model)
    _add_middleware_edge(
        graph,
        name=f"{middleware_w_before_agent[-1].name}.before_agent",
        default_destination=loop_entry_node,
        model_destination=loop_entry_node,
        end_destination=exit_node,
        can_jump_to=_get_can_jump_to(middleware_w_before_agent[-1], "before_agent"),
    )
# Add before_model middleware edges
if middleware_w_before_model:
    for m1, m2 in itertools.pairwise(middleware_w_before_model):
        _add_middleware_edge(
            graph,
            name=f"{m1.name}.before_model",
            default_destination=f"{m2.name}.before_model",
            model_destination=loop_entry_node,
            end_destination=exit_node,
            can_jump_to=_get_can_jump_to(m1, "before_model"),
        )
    # Go directly to model after the last before_model
    _add_middleware_edge(
        graph,
        name=f"{middleware_w_before_model[-1].name}.before_model",
        default_destination="model",
        model_destination=loop_entry_node,
        end_destination=exit_node,
        can_jump_to=_get_can_jump_to(middleware_w_before_model[-1], "before_model"),
    )
# Add after_model middleware edges
# NOTE: after_model middleware run in reverse registration order (last added
# runs first), hence the reversed iteration below.
if middleware_w_after_model:
    graph.add_edge("model", f"{middleware_w_after_model[-1].name}.after_model")
    for idx in range(len(middleware_w_after_model) - 1, 0, -1):
        m1 = middleware_w_after_model[idx]
        m2 = middleware_w_after_model[idx - 1]
        _add_middleware_edge(
            graph,
            name=f"{m1.name}.after_model",
            default_destination=f"{m2.name}.after_model",
            model_destination=loop_entry_node,
            end_destination=exit_node,
            can_jump_to=_get_can_jump_to(m1, "after_model"),
        )
    # Note: Connection from after_model to after_agent/END is handled above
    # in the conditional edges section
# Add after_agent middleware edges
if middleware_w_after_agent:
    # Chain after_agent middleware (runs once at the very end, before END)
    for idx in range(len(middleware_w_after_agent) - 1, 0, -1):
        m1 = middleware_w_after_agent[idx]
        m2 = middleware_w_after_agent[idx - 1]
        _add_middleware_edge(
            graph,
            name=f"{m1.name}.after_agent",
            default_destination=f"{m2.name}.after_agent",
            model_destination=loop_entry_node,
            end_destination=exit_node,
            can_jump_to=_get_can_jump_to(m1, "after_agent"),
        )
    # Connect the last after_agent to END
    _add_middleware_edge(
        graph,
        name=f"{middleware_w_after_agent[0].name}.after_agent",
        default_destination=END,
        model_destination=loop_entry_node,
        end_destination=exit_node,
        can_jump_to=_get_can_jump_to(middleware_w_after_agent[0], "after_agent"),
    )
# Compile the graph; the high recursion limit accommodates long agent loops.
return graph.compile(
    checkpointer=checkpointer,
    store=store,
    interrupt_before=interrupt_before,
    interrupt_after=interrupt_after,
    debug=debug,
    name=name,
    cache=cache,
).with_config({"recursion_limit": 10_000})
  1272. def _resolve_jump(
  1273. jump_to: JumpTo | None,
  1274. *,
  1275. model_destination: str,
  1276. end_destination: str,
  1277. ) -> str | None:
  1278. if jump_to == "model":
  1279. return model_destination
  1280. if jump_to == "end":
  1281. return end_destination
  1282. if jump_to == "tools":
  1283. return "tools"
  1284. return None
  1285. def _fetch_last_ai_and_tool_messages(
  1286. messages: list[AnyMessage],
  1287. ) -> tuple[AIMessage, list[ToolMessage]]:
  1288. last_ai_index: int
  1289. last_ai_message: AIMessage
  1290. for i in range(len(messages) - 1, -1, -1):
  1291. if isinstance(messages[i], AIMessage):
  1292. last_ai_index = i
  1293. last_ai_message = cast("AIMessage", messages[i])
  1294. break
  1295. tool_messages = [m for m in messages[last_ai_index + 1 :] if isinstance(m, ToolMessage)]
  1296. return last_ai_message, tool_messages
def _make_model_to_tools_edge(
    *,
    model_destination: str,
    structured_output_tools: dict[str, OutputToolBinding],
    end_destination: str,
) -> Callable[[dict[str, Any]], str | list[Send] | None]:
    """Build the conditional edge routing from the model node to tools or exit.

    Args:
        model_destination: Node to jump back to for another model turn.
        structured_output_tools: Tool names reserved for structured output;
            calls to these are never dispatched to the tool node.
        end_destination: Node that ends the loop (after_agent or END).

    Returns:
        A callable that maps agent state to a destination node name, a list of
        `Send` packets (one per pending tool call), or `None`.
    """

    def model_to_tools(
        state: dict[str, Any],
    ) -> str | list[Send] | None:
        # 1. if there's an explicit jump_to in the state, use it
        if jump_to := state.get("jump_to"):
            return _resolve_jump(
                jump_to,
                model_destination=model_destination,
                end_destination=end_destination,
            )
        last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
        # Tool-call ids already answered by a ToolMessage this turn.
        tool_message_ids = [m.tool_call_id for m in tool_messages]
        # 2. if the model hasn't called any tools, exit the loop
        # this is the classic exit condition for an agent loop
        if len(last_ai_message.tool_calls) == 0:
            return end_destination
        # Unanswered tool calls, excluding structured-output tools (those are
        # consumed by the model node itself, not the tool node).
        pending_tool_calls = [
            c
            for c in last_ai_message.tool_calls
            if c["id"] not in tool_message_ids and c["name"] not in structured_output_tools
        ]
        # 3. if there are pending tool calls, jump to the tool node
        if pending_tool_calls:
            # Each Send carries one tool call plus a snapshot of the state so
            # tools execute in parallel with full context.
            return [
                Send(
                    "tools",
                    ToolCallWithContext(
                        __type="tool_call_with_context",
                        tool_call=tool_call,
                        state=state,
                    ),
                )
                for tool_call in pending_tool_calls
            ]
        # 4. if there is a structured response, exit the loop
        if "structured_response" in state:
            return end_destination
        # 5. AIMessage has tool calls, but there are no pending tool calls
        # which suggests the injection of artificial tool messages. jump to the model node
        return model_destination

    return model_to_tools
  1344. def _make_model_to_model_edge(
  1345. *,
  1346. model_destination: str,
  1347. end_destination: str,
  1348. ) -> Callable[[dict[str, Any]], str | list[Send] | None]:
  1349. def model_to_model(
  1350. state: dict[str, Any],
  1351. ) -> str | list[Send] | None:
  1352. # 1. Priority: Check for explicit jump_to directive from middleware
  1353. if jump_to := state.get("jump_to"):
  1354. return _resolve_jump(
  1355. jump_to,
  1356. model_destination=model_destination,
  1357. end_destination=end_destination,
  1358. )
  1359. # 2. Exit condition: A structured response was generated
  1360. if "structured_response" in state:
  1361. return end_destination
  1362. # 3. Default: Continue the loop, there may have been an issue
  1363. # with structured output generation, so we need to retry
  1364. return model_destination
  1365. return model_to_model
  1366. def _make_tools_to_model_edge(
  1367. *,
  1368. tool_node: ToolNode,
  1369. model_destination: str,
  1370. structured_output_tools: dict[str, OutputToolBinding],
  1371. end_destination: str,
  1372. ) -> Callable[[dict[str, Any]], str | None]:
  1373. def tools_to_model(state: dict[str, Any]) -> str | None:
  1374. last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
  1375. # 1. Exit condition: All executed tools have return_direct=True
  1376. # Filter to only client-side tools (provider tools are not in tool_node)
  1377. client_side_tool_calls = [
  1378. c for c in last_ai_message.tool_calls if c["name"] in tool_node.tools_by_name
  1379. ]
  1380. if client_side_tool_calls and all(
  1381. tool_node.tools_by_name[c["name"]].return_direct for c in client_side_tool_calls
  1382. ):
  1383. return end_destination
  1384. # 2. Exit condition: A structured output tool was executed
  1385. if any(t.name in structured_output_tools for t in tool_messages):
  1386. return end_destination
  1387. # 3. Default: Continue the loop
  1388. # Tool execution completed successfully, route back to the model
  1389. # so it can process the tool results and decide the next action.
  1390. return model_destination
  1391. return tools_to_model
  1392. def _add_middleware_edge(
  1393. graph: StateGraph[
  1394. AgentState[ResponseT], ContextT, _InputAgentState, _OutputAgentState[ResponseT]
  1395. ],
  1396. *,
  1397. name: str,
  1398. default_destination: str,
  1399. model_destination: str,
  1400. end_destination: str,
  1401. can_jump_to: list[JumpTo] | None,
  1402. ) -> None:
  1403. """Add an edge to the graph for a middleware node.
  1404. Args:
  1405. graph: The graph to add the edge to.
  1406. name: The name of the middleware node.
  1407. default_destination: The default destination for the edge.
  1408. model_destination: The destination for the edge to the model.
  1409. end_destination: The destination for the edge to the end.
  1410. can_jump_to: The conditionally jumpable destinations for the edge.
  1411. """
  1412. if can_jump_to:
  1413. def jump_edge(state: dict[str, Any]) -> str:
  1414. return (
  1415. _resolve_jump(
  1416. state.get("jump_to"),
  1417. model_destination=model_destination,
  1418. end_destination=end_destination,
  1419. )
  1420. or default_destination
  1421. )
  1422. destinations = [default_destination]
  1423. if "end" in can_jump_to:
  1424. destinations.append(end_destination)
  1425. if "tools" in can_jump_to:
  1426. destinations.append("tools")
  1427. if "model" in can_jump_to and name != model_destination:
  1428. destinations.append(model_destination)
  1429. graph.add_conditional_edges(name, RunnableCallable(jump_edge, trace=False), destinations)
  1430. else:
  1431. graph.add_edge(name, default_destination)
# Public API of this module: only the agent constructor is exported.
__all__ = [
    "create_agent",
]