| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432 |
- """Types for setting agent response formats."""
- from __future__ import annotations
- import uuid
- from dataclasses import dataclass, is_dataclass
- from types import UnionType
- from typing import (
- TYPE_CHECKING,
- Any,
- Generic,
- Literal,
- TypeVar,
- Union,
- get_args,
- get_origin,
- )
- from langchain_core.tools import BaseTool, StructuredTool
- from pydantic import BaseModel, TypeAdapter
- from typing_extensions import Self, is_typeddict
- if TYPE_CHECKING:
- from collections.abc import Callable, Iterable
- from langchain_core.messages import AIMessage
# A structured output schema may be a Pydantic model, a dataclass, a TypedDict,
# or a JSON schema dict; `SchemaT` is the type variable standing for that schema.
SchemaT = TypeVar("SchemaT")

# Tag naming which of the supported schema flavours a given schema belongs to.
SchemaKind = Literal[
    "pydantic",
    "dataclass",
    "typeddict",
    "json_schema",
]
class StructuredOutputError(Exception):
    """Common ancestor of all structured output errors."""

    # Declared only; no value is assigned at class level.
    ai_message: AIMessage
class MultipleStructuredOutputsError(StructuredOutputError):
    """Signals that the model made several structured output tool calls instead of one."""

    def __init__(self, tool_names: list[str], ai_message: AIMessage) -> None:
        """Initialize `MultipleStructuredOutputsError`.

        Args:
            tool_names: Names of the structured output tools that were called.
            ai_message: The AI message carrying the invalid multiple tool calls.
        """
        joined_names = ", ".join(tool_names)
        message = (
            f"Model incorrectly returned multiple structured responses ({joined_names}) "
            "when only one is expected."
        )
        super().__init__(message)
        self.tool_names = tool_names
        self.ai_message = ai_message
class StructuredOutputValidationError(StructuredOutputError):
    """Signals that a structured output tool call's arguments did not parse against the schema."""

    def __init__(self, tool_name: str, source: Exception, ai_message: AIMessage) -> None:
        """Initialize `StructuredOutputValidationError`.

        Args:
            tool_name: Name of the tool whose arguments failed to parse.
            source: The underlying exception.
            ai_message: The AI message carrying the invalid structured output.
        """
        message = f"Failed to parse structured output for tool '{tool_name}': {source}."
        super().__init__(message)
        self.tool_name = tool_name
        self.source = source
        self.ai_message = ai_message
- def _parse_with_schema(
- schema: type[SchemaT] | dict, schema_kind: SchemaKind, data: dict[str, Any]
- ) -> Any:
- """Parse data using for any supported schema type.
- Args:
- schema: The schema type (Pydantic model, `dataclass`, or `TypedDict`)
- schema_kind: One of `"pydantic"`, `"dataclass"`, `"typeddict"`, or
- `"json_schema"`
- data: The data to parse
- Returns:
- The parsed instance according to the schema type
- Raises:
- ValueError: If parsing fails
- """
- if schema_kind == "json_schema":
- return data
- try:
- adapter: TypeAdapter[SchemaT] = TypeAdapter(schema)
- return adapter.validate_python(data)
- except Exception as e:
- schema_name = getattr(schema, "__name__", str(schema))
- msg = f"Failed to parse data to {schema_name}: {e}"
- raise ValueError(msg) from e
@dataclass(init=False)
class _SchemaSpec(Generic[SchemaT]):
    """Describes a structured output schema."""

    schema: type[SchemaT]
    """The schema for the response, can be a Pydantic model, `dataclass`, `TypedDict`,
    or JSON schema dict."""

    name: str
    """Name of the schema, used for tool calling.

    If not provided, the name is the `title` entry of a JSON schema dict or the
    `__name__` of a class, falling back to a generated `response_format_<4 chars>` name.
    """

    description: str
    """Custom description of the schema.

    If not provided, the `description` entry of a JSON schema dict or the docstring of
    a class is used, defaulting to an empty string.
    """

    schema_kind: SchemaKind
    """The kind of schema."""

    json_schema: dict[str, Any]
    """JSON schema associated with the schema."""

    strict: bool = False
    """Whether to enforce strict validation of the schema."""

    def __init__(
        self,
        schema: type[SchemaT],
        *,
        name: str | None = None,
        description: str | None = None,
        strict: bool = False,
    ) -> None:
        """Initialize `_SchemaSpec` with a schema and optional parameters.

        Args:
            schema: Pydantic model, `dataclass`, `TypedDict`, or JSON schema dict.
            name: Explicit schema name; derived from the schema when empty or omitted.
            description: Explicit description; derived from the schema when empty
                or omitted.
            strict: Stored on the spec as given.

        Raises:
            ValueError: If `schema` is not one of the supported schema types.
        """
        self.schema = schema
        self.strict = strict

        schema_is_dict = isinstance(schema, dict)

        if name:
            self.name = name
        else:
            # Short random suffix for schemas that carry no name of their own.
            fallback_name = f"response_format_{str(uuid.uuid4())[:4]}"
            if schema_is_dict:
                self.name = str(schema.get("title", fallback_name))
            else:
                self.name = str(getattr(schema, "__name__", fallback_name))

        if schema_is_dict:
            derived_description = schema.get("description", "")
        else:
            derived_description = getattr(schema, "__doc__", None) or ""
        self.description = description or derived_description

        if schema_is_dict:
            self.schema_kind = "json_schema"
            self.json_schema = schema
        elif isinstance(schema, type) and issubclass(schema, BaseModel):
            self.schema_kind = "pydantic"
            self.json_schema = schema.model_json_schema()
        elif isinstance(schema, type) and is_dataclass(schema):
            # `is_dataclass` is also true for dataclass *instances*; require a class so
            # an instance reaches the explicit ValueError below instead of being
            # handed to `TypeAdapter`.
            self.schema_kind = "dataclass"
            self.json_schema = TypeAdapter(schema).json_schema()
        elif is_typeddict(schema):
            self.schema_kind = "typeddict"
            self.json_schema = TypeAdapter(schema).json_schema()
        else:
            msg = (
                f"Unsupported schema type: {type(schema)}. "
                f"Supported types: Pydantic models, dataclasses, TypedDicts, and JSON schema dicts."
            )
            raise ValueError(msg)
@dataclass(init=False)
class ToolStrategy(Generic[SchemaT]):
    """Use a tool calling strategy for model responses."""

    schema: type[SchemaT]
    """Schema for the tool calls."""

    schema_specs: list[_SchemaSpec[SchemaT]]
    """Schema specs for the tool calls, one per leaf variant of `schema`."""

    tool_message_content: str | None
    """The content of the tool message to be returned when the model calls
    an artificial structured output tool."""

    handle_errors: (
        bool | str | type[Exception] | tuple[type[Exception], ...] | Callable[[Exception], str]
    )
    """Error handling strategy for structured output via `ToolStrategy`.

    - `True`: Catch all errors with default error template
    - `str`: Catch all errors with this custom message
    - `type[Exception]`: Only catch this exception type with default message
    - `tuple[type[Exception], ...]`: Only catch these exception types with default
        message
    - `Callable[[Exception], str]`: Custom function that returns error message
    - `False`: No retry, let exceptions propagate
    """

    def __init__(
        self,
        schema: type[SchemaT],
        *,
        tool_message_content: str | None = None,
        handle_errors: bool
        | str
        | type[Exception]
        | tuple[type[Exception], ...]
        | Callable[[Exception], str] = True,
    ) -> None:
        """Initialize `ToolStrategy`.

        Stores the schema, tool message content, and error handling strategy, and
        builds one schema spec per leaf variant of `schema`.
        """

        def _iter_variants(candidate: Any) -> Iterable[Any]:
            """Yield the leaf variants of a Union or JSON Schema `oneOf`, recursively."""
            if get_origin(candidate) in (UnionType, Union):
                members = get_args(candidate)
            elif isinstance(candidate, dict) and "oneOf" in candidate:
                members = candidate["oneOf"]
            else:
                # Neither a union nor a `oneOf` container: this is a leaf.
                yield candidate
                return
            for member in members:
                yield from _iter_variants(member)

        self.schema = schema
        self.tool_message_content = tool_message_content
        self.handle_errors = handle_errors
        self.schema_specs = [_SchemaSpec(variant) for variant in _iter_variants(schema)]
@dataclass(init=False)
class ProviderStrategy(Generic[SchemaT]):
    """Use the model provider's native structured output method."""

    schema: type[SchemaT]
    """Schema for native mode."""

    schema_spec: _SchemaSpec[SchemaT]
    """Schema spec for native mode, built from `schema`."""

    def __init__(
        self,
        schema: type[SchemaT],
    ) -> None:
        """Initialize `ProviderStrategy` with a schema and its derived schema spec."""
        self.schema = schema
        self.schema_spec = _SchemaSpec(schema)

    def to_model_kwargs(self) -> dict[str, Any]:
        """Build the kwargs to bind to a model to force structured output.

        Returns:
            A dict with a single `"response_format"` entry describing the JSON schema.
        """
        # OpenAI:
        # - see https://platform.openai.com/docs/guides/structured-outputs
        spec = self.schema_spec
        json_schema_payload = {
            "name": spec.name,
            "schema": spec.json_schema,
        }
        return {
            "response_format": {
                "type": "json_schema",
                "json_schema": json_schema_payload,
            }
        }
@dataclass
class OutputToolBinding(Generic[SchemaT]):
    """Metadata for a structured output produced through a tool call.

    Bundles the original schema, its kind classification, and the tool
    built from the schema for the tools strategy.
    """

    schema: type[SchemaT]
    """The original schema provided for structured output
    (Pydantic model, dataclass, TypedDict, or JSON schema dict)."""

    schema_kind: SchemaKind
    """Classification of the schema type for proper response construction."""

    tool: BaseTool
    """LangChain tool instance created from the schema for model binding."""

    @classmethod
    def from_schema_spec(cls, schema_spec: _SchemaSpec[SchemaT]) -> Self:
        """Create an `OutputToolBinding` instance from a `SchemaSpec`.

        Args:
            schema_spec: The `SchemaSpec` to convert

        Returns:
            An `OutputToolBinding` whose tool is a `StructuredTool` built from the
            spec's JSON schema, name, and description
        """
        output_tool = StructuredTool(
            args_schema=schema_spec.json_schema,
            name=schema_spec.name,
            description=schema_spec.description,
        )
        return cls(
            schema=schema_spec.schema,
            schema_kind=schema_spec.schema_kind,
            tool=output_tool,
        )

    def parse(self, tool_args: dict[str, Any]) -> SchemaT:
        """Parse tool call arguments according to the bound schema.

        Args:
            tool_args: The arguments from the tool call

        Returns:
            The parsed response according to the schema type

        Raises:
            ValueError: If parsing fails
        """
        parsed = _parse_with_schema(self.schema, self.schema_kind, tool_args)
        return parsed
@dataclass
class ProviderStrategyBinding(Generic[SchemaT]):
    """Metadata for a structured output produced natively by the provider.

    Bundles the original schema and its kind classification, and provides the
    parsing logic that turns an `AIMessage`'s JSON text into the schema's type.
    """

    schema: type[SchemaT]
    """The original schema provided for structured output
    (Pydantic model, `dataclass`, `TypedDict`, or JSON schema dict)."""

    schema_kind: SchemaKind
    """Classification of the schema type for proper response construction."""

    @classmethod
    def from_schema_spec(cls, schema_spec: _SchemaSpec[SchemaT]) -> Self:
        """Create a `ProviderStrategyBinding` instance from a `SchemaSpec`.

        Args:
            schema_spec: The `SchemaSpec` to convert

        Returns:
            A `ProviderStrategyBinding` carrying the spec's schema and schema kind
        """
        return cls(schema=schema_spec.schema, schema_kind=schema_spec.schema_kind)

    def parse(self, response: AIMessage) -> SchemaT:
        """Parse `AIMessage` content according to the schema.

        Args:
            response: The `AIMessage` containing the structured output

        Returns:
            The parsed response according to the schema

        Raises:
            ValueError: If JSON parsing or schema validation fails
        """
        import json

        # The message text is expected to be a JSON document.
        raw_text = self._extract_text_content_from_message(response)
        try:
            payload = json.loads(raw_text)
        except Exception as exc:
            label = getattr(self.schema, "__name__", "response_format")
            msg = (
                f"Native structured output expected valid JSON for {label}, "
                f"but parsing failed: {exc}."
            )
            raise ValueError(msg) from exc

        # Hand the decoded data to the schema-specific parser.
        return _parse_with_schema(self.schema, self.schema_kind, payload)

    def _extract_text_content_from_message(self, message: AIMessage) -> str:
        """Collect the text carried by an `AIMessage`.

        Args:
            message: The AI message to extract text from

        Returns:
            The content itself when it is a string; the concatenated text pieces
            when it is a list; otherwise `str(content)`
        """
        content = message.content
        if isinstance(content, str):
            return content
        if not isinstance(content, list):
            return str(content)

        fragments: list[str] = []
        for block in content:
            if not isinstance(block, dict):
                # Non-dict entries are stringified as they are.
                fragments.append(str(block))
                continue
            if block.get("type") == "text" and "text" in block:
                fragments.append(str(block["text"]))
            elif "content" in block and isinstance(block["content"], str):
                fragments.append(block["content"])
            # Dict entries matching neither shape contribute no text.
        return "".join(fragments)
class AutoStrategy(Generic[SchemaT]):
    """Request automatic selection of the structured output strategy.

    The class itself only stores the schema.
    """

    schema: type[SchemaT]
    """Schema for automatic mode."""

    def __init__(self, schema: type[SchemaT]) -> None:
        """Initialize `AutoStrategy` with the given schema."""
        self.schema = schema
# Union of the three strategy wrappers defined in this module.
ResponseFormat = (
    ToolStrategy[SchemaT]
    | ProviderStrategy[SchemaT]
    | AutoStrategy[SchemaT]
)
|