structured_output.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432
  1. """Types for setting agent response formats."""
  2. from __future__ import annotations
  3. import uuid
  4. from dataclasses import dataclass, is_dataclass
  5. from types import UnionType
  6. from typing import (
  7. TYPE_CHECKING,
  8. Any,
  9. Generic,
  10. Literal,
  11. TypeVar,
  12. Union,
  13. get_args,
  14. get_origin,
  15. )
  16. from langchain_core.tools import BaseTool, StructuredTool
  17. from pydantic import BaseModel, TypeAdapter
  18. from typing_extensions import Self, is_typeddict
  19. if TYPE_CHECKING:
  20. from collections.abc import Callable, Iterable
  21. from langchain_core.messages import AIMessage
# Supported schema types: Pydantic models, dataclasses, TypedDict, JSON schema dicts.
# SchemaT is the type variable the strategy and binding classes in this module are
# generic over.
SchemaT = TypeVar("SchemaT")
# Tag recording which of the supported schema families a given schema belongs to.
SchemaKind = Literal["pydantic", "dataclass", "typeddict", "json_schema"]
  25. class StructuredOutputError(Exception):
  26. """Base class for structured output errors."""
  27. ai_message: AIMessage
  28. class MultipleStructuredOutputsError(StructuredOutputError):
  29. """Raised when model returns multiple structured output tool calls when only one is expected."""
  30. def __init__(self, tool_names: list[str], ai_message: AIMessage) -> None:
  31. """Initialize `MultipleStructuredOutputsError`.
  32. Args:
  33. tool_names: The names of the tools called for structured output.
  34. ai_message: The AI message that contained the invalid multiple tool calls.
  35. """
  36. self.tool_names = tool_names
  37. self.ai_message = ai_message
  38. super().__init__(
  39. "Model incorrectly returned multiple structured responses "
  40. f"({', '.join(tool_names)}) when only one is expected."
  41. )
  42. class StructuredOutputValidationError(StructuredOutputError):
  43. """Raised when structured output tool call arguments fail to parse according to the schema."""
  44. def __init__(self, tool_name: str, source: Exception, ai_message: AIMessage) -> None:
  45. """Initialize `StructuredOutputValidationError`.
  46. Args:
  47. tool_name: The name of the tool that failed.
  48. source: The exception that occurred.
  49. ai_message: The AI message that contained the invalid structured output.
  50. """
  51. self.tool_name = tool_name
  52. self.source = source
  53. self.ai_message = ai_message
  54. super().__init__(f"Failed to parse structured output for tool '{tool_name}': {source}.")
  55. def _parse_with_schema(
  56. schema: type[SchemaT] | dict, schema_kind: SchemaKind, data: dict[str, Any]
  57. ) -> Any:
  58. """Parse data using for any supported schema type.
  59. Args:
  60. schema: The schema type (Pydantic model, `dataclass`, or `TypedDict`)
  61. schema_kind: One of `"pydantic"`, `"dataclass"`, `"typeddict"`, or
  62. `"json_schema"`
  63. data: The data to parse
  64. Returns:
  65. The parsed instance according to the schema type
  66. Raises:
  67. ValueError: If parsing fails
  68. """
  69. if schema_kind == "json_schema":
  70. return data
  71. try:
  72. adapter: TypeAdapter[SchemaT] = TypeAdapter(schema)
  73. return adapter.validate_python(data)
  74. except Exception as e:
  75. schema_name = getattr(schema, "__name__", str(schema))
  76. msg = f"Failed to parse data to {schema_name}: {e}"
  77. raise ValueError(msg) from e
  78. @dataclass(init=False)
  79. class _SchemaSpec(Generic[SchemaT]):
  80. """Describes a structured output schema."""
  81. schema: type[SchemaT]
  82. """The schema for the response, can be a Pydantic model, `dataclass`, `TypedDict`,
  83. or JSON schema dict."""
  84. name: str
  85. """Name of the schema, used for tool calling.
  86. If not provided, the name will be the model name or `"response_format"` if it's a
  87. JSON schema.
  88. """
  89. description: str
  90. """Custom description of the schema.
  91. If not provided, provided will use the model's docstring.
  92. """
  93. schema_kind: SchemaKind
  94. """The kind of schema."""
  95. json_schema: dict[str, Any]
  96. """JSON schema associated with the schema."""
  97. strict: bool = False
  98. """Whether to enforce strict validation of the schema."""
  99. def __init__(
  100. self,
  101. schema: type[SchemaT],
  102. *,
  103. name: str | None = None,
  104. description: str | None = None,
  105. strict: bool = False,
  106. ) -> None:
  107. """Initialize SchemaSpec with schema and optional parameters."""
  108. self.schema = schema
  109. if name:
  110. self.name = name
  111. elif isinstance(schema, dict):
  112. self.name = str(schema.get("title", f"response_format_{str(uuid.uuid4())[:4]}"))
  113. else:
  114. self.name = str(getattr(schema, "__name__", f"response_format_{str(uuid.uuid4())[:4]}"))
  115. self.description = description or (
  116. schema.get("description", "")
  117. if isinstance(schema, dict)
  118. else getattr(schema, "__doc__", None) or ""
  119. )
  120. self.strict = strict
  121. if isinstance(schema, dict):
  122. self.schema_kind = "json_schema"
  123. self.json_schema = schema
  124. elif isinstance(schema, type) and issubclass(schema, BaseModel):
  125. self.schema_kind = "pydantic"
  126. self.json_schema = schema.model_json_schema()
  127. elif is_dataclass(schema):
  128. self.schema_kind = "dataclass"
  129. self.json_schema = TypeAdapter(schema).json_schema()
  130. elif is_typeddict(schema):
  131. self.schema_kind = "typeddict"
  132. self.json_schema = TypeAdapter(schema).json_schema()
  133. else:
  134. msg = (
  135. f"Unsupported schema type: {type(schema)}. "
  136. f"Supported types: Pydantic models, dataclasses, TypedDicts, and JSON schema dicts."
  137. )
  138. raise ValueError(msg)
  139. @dataclass(init=False)
  140. class ToolStrategy(Generic[SchemaT]):
  141. """Use a tool calling strategy for model responses."""
  142. schema: type[SchemaT]
  143. """Schema for the tool calls."""
  144. schema_specs: list[_SchemaSpec[SchemaT]]
  145. """Schema specs for the tool calls."""
  146. tool_message_content: str | None
  147. """The content of the tool message to be returned when the model calls
  148. an artificial structured output tool."""
  149. handle_errors: (
  150. bool | str | type[Exception] | tuple[type[Exception], ...] | Callable[[Exception], str]
  151. )
  152. """Error handling strategy for structured output via `ToolStrategy`.
  153. - `True`: Catch all errors with default error template
  154. - `str`: Catch all errors with this custom message
  155. - `type[Exception]`: Only catch this exception type with default message
  156. - `tuple[type[Exception], ...]`: Only catch these exception types with default
  157. message
  158. - `Callable[[Exception], str]`: Custom function that returns error message
  159. - `False`: No retry, let exceptions propagate
  160. """
  161. def __init__(
  162. self,
  163. schema: type[SchemaT],
  164. *,
  165. tool_message_content: str | None = None,
  166. handle_errors: bool
  167. | str
  168. | type[Exception]
  169. | tuple[type[Exception], ...]
  170. | Callable[[Exception], str] = True,
  171. ) -> None:
  172. """Initialize `ToolStrategy`.
  173. Initialize `ToolStrategy` with schemas, tool message content, and error handling
  174. strategy.
  175. """
  176. self.schema = schema
  177. self.tool_message_content = tool_message_content
  178. self.handle_errors = handle_errors
  179. def _iter_variants(schema: Any) -> Iterable[Any]:
  180. """Yield leaf variants from Union and JSON Schema oneOf."""
  181. if get_origin(schema) in (UnionType, Union):
  182. for arg in get_args(schema):
  183. yield from _iter_variants(arg)
  184. return
  185. if isinstance(schema, dict) and "oneOf" in schema:
  186. for sub in schema.get("oneOf", []):
  187. yield from _iter_variants(sub)
  188. return
  189. yield schema
  190. self.schema_specs = [_SchemaSpec(s) for s in _iter_variants(schema)]
  191. @dataclass(init=False)
  192. class ProviderStrategy(Generic[SchemaT]):
  193. """Use the model provider's native structured output method."""
  194. schema: type[SchemaT]
  195. """Schema for native mode."""
  196. schema_spec: _SchemaSpec[SchemaT]
  197. """Schema spec for native mode."""
  198. def __init__(
  199. self,
  200. schema: type[SchemaT],
  201. ) -> None:
  202. """Initialize ProviderStrategy with schema."""
  203. self.schema = schema
  204. self.schema_spec = _SchemaSpec(schema)
  205. def to_model_kwargs(self) -> dict[str, Any]:
  206. """Convert to kwargs to bind to a model to force structured output."""
  207. # OpenAI:
  208. # - see https://platform.openai.com/docs/guides/structured-outputs
  209. response_format = {
  210. "type": "json_schema",
  211. "json_schema": {
  212. "name": self.schema_spec.name,
  213. "schema": self.schema_spec.json_schema,
  214. },
  215. }
  216. return {"response_format": response_format}
  217. @dataclass
  218. class OutputToolBinding(Generic[SchemaT]):
  219. """Information for tracking structured output tool metadata.
  220. This contains all necessary information to handle structured responses
  221. generated via tool calls, including the original schema, its type classification,
  222. and the corresponding tool implementation used by the tools strategy.
  223. """
  224. schema: type[SchemaT]
  225. """The original schema provided for structured output
  226. (Pydantic model, dataclass, TypedDict, or JSON schema dict)."""
  227. schema_kind: SchemaKind
  228. """Classification of the schema type for proper response construction."""
  229. tool: BaseTool
  230. """LangChain tool instance created from the schema for model binding."""
  231. @classmethod
  232. def from_schema_spec(cls, schema_spec: _SchemaSpec[SchemaT]) -> Self:
  233. """Create an `OutputToolBinding` instance from a `SchemaSpec`.
  234. Args:
  235. schema_spec: The `SchemaSpec` to convert
  236. Returns:
  237. An `OutputToolBinding` instance with the appropriate tool created
  238. """
  239. return cls(
  240. schema=schema_spec.schema,
  241. schema_kind=schema_spec.schema_kind,
  242. tool=StructuredTool(
  243. args_schema=schema_spec.json_schema,
  244. name=schema_spec.name,
  245. description=schema_spec.description,
  246. ),
  247. )
  248. def parse(self, tool_args: dict[str, Any]) -> SchemaT:
  249. """Parse tool arguments according to the schema.
  250. Args:
  251. tool_args: The arguments from the tool call
  252. Returns:
  253. The parsed response according to the schema type
  254. Raises:
  255. ValueError: If parsing fails
  256. """
  257. return _parse_with_schema(self.schema, self.schema_kind, tool_args)
  258. @dataclass
  259. class ProviderStrategyBinding(Generic[SchemaT]):
  260. """Information for tracking native structured output metadata.
  261. This contains all necessary information to handle structured responses
  262. generated via native provider output, including the original schema,
  263. its type classification, and parsing logic for provider-enforced JSON.
  264. """
  265. schema: type[SchemaT]
  266. """The original schema provided for structured output
  267. (Pydantic model, `dataclass`, `TypedDict`, or JSON schema dict)."""
  268. schema_kind: SchemaKind
  269. """Classification of the schema type for proper response construction."""
  270. @classmethod
  271. def from_schema_spec(cls, schema_spec: _SchemaSpec[SchemaT]) -> Self:
  272. """Create a `ProviderStrategyBinding` instance from a `SchemaSpec`.
  273. Args:
  274. schema_spec: The `SchemaSpec` to convert
  275. Returns:
  276. A `ProviderStrategyBinding` instance for parsing native structured output
  277. """
  278. return cls(
  279. schema=schema_spec.schema,
  280. schema_kind=schema_spec.schema_kind,
  281. )
  282. def parse(self, response: AIMessage) -> SchemaT:
  283. """Parse `AIMessage` content according to the schema.
  284. Args:
  285. response: The `AIMessage` containing the structured output
  286. Returns:
  287. The parsed response according to the schema
  288. Raises:
  289. ValueError: If text extraction, JSON parsing or schema validation fails
  290. """
  291. # Extract text content from AIMessage and parse as JSON
  292. raw_text = self._extract_text_content_from_message(response)
  293. import json
  294. try:
  295. data = json.loads(raw_text)
  296. except Exception as e:
  297. schema_name = getattr(self.schema, "__name__", "response_format")
  298. msg = (
  299. f"Native structured output expected valid JSON for {schema_name}, "
  300. f"but parsing failed: {e}."
  301. )
  302. raise ValueError(msg) from e
  303. # Parse according to schema
  304. return _parse_with_schema(self.schema, self.schema_kind, data)
  305. def _extract_text_content_from_message(self, message: AIMessage) -> str:
  306. """Extract text content from an AIMessage.
  307. Args:
  308. message: The AI message to extract text from
  309. Returns:
  310. The extracted text content
  311. """
  312. content = message.content
  313. if isinstance(content, str):
  314. return content
  315. if isinstance(content, list):
  316. parts: list[str] = []
  317. for c in content:
  318. if isinstance(c, dict):
  319. if c.get("type") == "text" and "text" in c:
  320. parts.append(str(c["text"]))
  321. elif "content" in c and isinstance(c["content"], str):
  322. parts.append(c["content"])
  323. else:
  324. parts.append(str(c))
  325. return "".join(parts)
  326. return str(content)
  327. class AutoStrategy(Generic[SchemaT]):
  328. """Automatically select the best strategy for structured output."""
  329. schema: type[SchemaT]
  330. """Schema for automatic mode."""
  331. def __init__(
  332. self,
  333. schema: type[SchemaT],
  334. ) -> None:
  335. """Initialize AutoStrategy with schema."""
  336. self.schema = schema
# Union of the three strategy classes defined in this module.
ResponseFormat = ToolStrategy[SchemaT] | ProviderStrategy[SchemaT] | AutoStrategy[SchemaT]