"""Types for setting agent response formats.""" from __future__ import annotations import uuid from dataclasses import dataclass, is_dataclass from types import UnionType from typing import ( TYPE_CHECKING, Any, Generic, Literal, TypeVar, Union, get_args, get_origin, ) from langchain_core.tools import BaseTool, StructuredTool from pydantic import BaseModel, TypeAdapter from typing_extensions import Self, is_typeddict if TYPE_CHECKING: from collections.abc import Callable, Iterable from langchain_core.messages import AIMessage # Supported schema types: Pydantic models, dataclasses, TypedDict, JSON schema dicts SchemaT = TypeVar("SchemaT") SchemaKind = Literal["pydantic", "dataclass", "typeddict", "json_schema"] class StructuredOutputError(Exception): """Base class for structured output errors.""" ai_message: AIMessage class MultipleStructuredOutputsError(StructuredOutputError): """Raised when model returns multiple structured output tool calls when only one is expected.""" def __init__(self, tool_names: list[str], ai_message: AIMessage) -> None: """Initialize `MultipleStructuredOutputsError`. Args: tool_names: The names of the tools called for structured output. ai_message: The AI message that contained the invalid multiple tool calls. """ self.tool_names = tool_names self.ai_message = ai_message super().__init__( "Model incorrectly returned multiple structured responses " f"({', '.join(tool_names)}) when only one is expected." ) class StructuredOutputValidationError(StructuredOutputError): """Raised when structured output tool call arguments fail to parse according to the schema.""" def __init__(self, tool_name: str, source: Exception, ai_message: AIMessage) -> None: """Initialize `StructuredOutputValidationError`. Args: tool_name: The name of the tool that failed. source: The exception that occurred. ai_message: The AI message that contained the invalid structured output. """ self.tool_name = tool_name self.source = source self.ai_message = ai_message super().__init__(f"Failed to parse structured output for tool '{tool_name}': {source}.") def _parse_with_schema( schema: type[SchemaT] | dict, schema_kind: SchemaKind, data: dict[str, Any] ) -> Any: """Parse data using for any supported schema type. Args: schema: The schema type (Pydantic model, `dataclass`, or `TypedDict`) schema_kind: One of `"pydantic"`, `"dataclass"`, `"typeddict"`, or `"json_schema"` data: The data to parse Returns: The parsed instance according to the schema type Raises: ValueError: If parsing fails """ if schema_kind == "json_schema": return data try: adapter: TypeAdapter[SchemaT] = TypeAdapter(schema) return adapter.validate_python(data) except Exception as e: schema_name = getattr(schema, "__name__", str(schema)) msg = f"Failed to parse data to {schema_name}: {e}" raise ValueError(msg) from e @dataclass(init=False) class _SchemaSpec(Generic[SchemaT]): """Describes a structured output schema.""" schema: type[SchemaT] """The schema for the response, can be a Pydantic model, `dataclass`, `TypedDict`, or JSON schema dict.""" name: str """Name of the schema, used for tool calling. If not provided, the name will be the model name or `"response_format"` if it's a JSON schema. """ description: str """Custom description of the schema. If not provided, provided will use the model's docstring. """ schema_kind: SchemaKind """The kind of schema.""" json_schema: dict[str, Any] """JSON schema associated with the schema.""" strict: bool = False """Whether to enforce strict validation of the schema.""" def __init__( self, schema: type[SchemaT], *, name: str | None = None, description: str | None = None, strict: bool = False, ) -> None: """Initialize SchemaSpec with schema and optional parameters.""" self.schema = schema if name: self.name = name elif isinstance(schema, dict): self.name = str(schema.get("title", f"response_format_{str(uuid.uuid4())[:4]}")) else: self.name = str(getattr(schema, "__name__", f"response_format_{str(uuid.uuid4())[:4]}")) self.description = description or ( schema.get("description", "") if isinstance(schema, dict) else getattr(schema, "__doc__", None) or "" ) self.strict = strict if isinstance(schema, dict): self.schema_kind = "json_schema" self.json_schema = schema elif isinstance(schema, type) and issubclass(schema, BaseModel): self.schema_kind = "pydantic" self.json_schema = schema.model_json_schema() elif is_dataclass(schema): self.schema_kind = "dataclass" self.json_schema = TypeAdapter(schema).json_schema() elif is_typeddict(schema): self.schema_kind = "typeddict" self.json_schema = TypeAdapter(schema).json_schema() else: msg = ( f"Unsupported schema type: {type(schema)}. " f"Supported types: Pydantic models, dataclasses, TypedDicts, and JSON schema dicts." ) raise ValueError(msg) @dataclass(init=False) class ToolStrategy(Generic[SchemaT]): """Use a tool calling strategy for model responses.""" schema: type[SchemaT] """Schema for the tool calls.""" schema_specs: list[_SchemaSpec[SchemaT]] """Schema specs for the tool calls.""" tool_message_content: str | None """The content of the tool message to be returned when the model calls an artificial structured output tool.""" handle_errors: ( bool | str | type[Exception] | tuple[type[Exception], ...] | Callable[[Exception], str] ) """Error handling strategy for structured output via `ToolStrategy`. - `True`: Catch all errors with default error template - `str`: Catch all errors with this custom message - `type[Exception]`: Only catch this exception type with default message - `tuple[type[Exception], ...]`: Only catch these exception types with default message - `Callable[[Exception], str]`: Custom function that returns error message - `False`: No retry, let exceptions propagate """ def __init__( self, schema: type[SchemaT], *, tool_message_content: str | None = None, handle_errors: bool | str | type[Exception] | tuple[type[Exception], ...] | Callable[[Exception], str] = True, ) -> None: """Initialize `ToolStrategy`. Initialize `ToolStrategy` with schemas, tool message content, and error handling strategy. """ self.schema = schema self.tool_message_content = tool_message_content self.handle_errors = handle_errors def _iter_variants(schema: Any) -> Iterable[Any]: """Yield leaf variants from Union and JSON Schema oneOf.""" if get_origin(schema) in (UnionType, Union): for arg in get_args(schema): yield from _iter_variants(arg) return if isinstance(schema, dict) and "oneOf" in schema: for sub in schema.get("oneOf", []): yield from _iter_variants(sub) return yield schema self.schema_specs = [_SchemaSpec(s) for s in _iter_variants(schema)] @dataclass(init=False) class ProviderStrategy(Generic[SchemaT]): """Use the model provider's native structured output method.""" schema: type[SchemaT] """Schema for native mode.""" schema_spec: _SchemaSpec[SchemaT] """Schema spec for native mode.""" def __init__( self, schema: type[SchemaT], ) -> None: """Initialize ProviderStrategy with schema.""" self.schema = schema self.schema_spec = _SchemaSpec(schema) def to_model_kwargs(self) -> dict[str, Any]: """Convert to kwargs to bind to a model to force structured output.""" # OpenAI: # - see https://platform.openai.com/docs/guides/structured-outputs response_format = { "type": "json_schema", "json_schema": { "name": self.schema_spec.name, "schema": self.schema_spec.json_schema, }, } return {"response_format": response_format} @dataclass class OutputToolBinding(Generic[SchemaT]): """Information for tracking structured output tool metadata. This contains all necessary information to handle structured responses generated via tool calls, including the original schema, its type classification, and the corresponding tool implementation used by the tools strategy. """ schema: type[SchemaT] """The original schema provided for structured output (Pydantic model, dataclass, TypedDict, or JSON schema dict).""" schema_kind: SchemaKind """Classification of the schema type for proper response construction.""" tool: BaseTool """LangChain tool instance created from the schema for model binding.""" @classmethod def from_schema_spec(cls, schema_spec: _SchemaSpec[SchemaT]) -> Self: """Create an `OutputToolBinding` instance from a `SchemaSpec`. Args: schema_spec: The `SchemaSpec` to convert Returns: An `OutputToolBinding` instance with the appropriate tool created """ return cls( schema=schema_spec.schema, schema_kind=schema_spec.schema_kind, tool=StructuredTool( args_schema=schema_spec.json_schema, name=schema_spec.name, description=schema_spec.description, ), ) def parse(self, tool_args: dict[str, Any]) -> SchemaT: """Parse tool arguments according to the schema. Args: tool_args: The arguments from the tool call Returns: The parsed response according to the schema type Raises: ValueError: If parsing fails """ return _parse_with_schema(self.schema, self.schema_kind, tool_args) @dataclass class ProviderStrategyBinding(Generic[SchemaT]): """Information for tracking native structured output metadata. This contains all necessary information to handle structured responses generated via native provider output, including the original schema, its type classification, and parsing logic for provider-enforced JSON. """ schema: type[SchemaT] """The original schema provided for structured output (Pydantic model, `dataclass`, `TypedDict`, or JSON schema dict).""" schema_kind: SchemaKind """Classification of the schema type for proper response construction.""" @classmethod def from_schema_spec(cls, schema_spec: _SchemaSpec[SchemaT]) -> Self: """Create a `ProviderStrategyBinding` instance from a `SchemaSpec`. Args: schema_spec: The `SchemaSpec` to convert Returns: A `ProviderStrategyBinding` instance for parsing native structured output """ return cls( schema=schema_spec.schema, schema_kind=schema_spec.schema_kind, ) def parse(self, response: AIMessage) -> SchemaT: """Parse `AIMessage` content according to the schema. Args: response: The `AIMessage` containing the structured output Returns: The parsed response according to the schema Raises: ValueError: If text extraction, JSON parsing or schema validation fails """ # Extract text content from AIMessage and parse as JSON raw_text = self._extract_text_content_from_message(response) import json try: data = json.loads(raw_text) except Exception as e: schema_name = getattr(self.schema, "__name__", "response_format") msg = ( f"Native structured output expected valid JSON for {schema_name}, " f"but parsing failed: {e}." ) raise ValueError(msg) from e # Parse according to schema return _parse_with_schema(self.schema, self.schema_kind, data) def _extract_text_content_from_message(self, message: AIMessage) -> str: """Extract text content from an AIMessage. Args: message: The AI message to extract text from Returns: The extracted text content """ content = message.content if isinstance(content, str): return content if isinstance(content, list): parts: list[str] = [] for c in content: if isinstance(c, dict): if c.get("type") == "text" and "text" in c: parts.append(str(c["text"])) elif "content" in c and isinstance(c["content"], str): parts.append(c["content"]) else: parts.append(str(c)) return "".join(parts) return str(content) class AutoStrategy(Generic[SchemaT]): """Automatically select the best strategy for structured output.""" schema: type[SchemaT] """Schema for automatic mode.""" def __init__( self, schema: type[SchemaT], ) -> None: """Initialize AutoStrategy with schema.""" self.schema = schema ResponseFormat = ToolStrategy[SchemaT] | ProviderStrategy[SchemaT] | AutoStrategy[SchemaT]