- """Factory functions for chat models."""
- from __future__ import annotations
- import warnings
- from importlib import util
- from typing import TYPE_CHECKING, Any, Literal, TypeAlias, cast, overload
- from langchain_core.language_models import BaseChatModel, LanguageModelInput
- from langchain_core.messages import AIMessage, AnyMessage
- from langchain_core.runnables import Runnable, RunnableConfig, ensure_config
- from typing_extensions import override
- if TYPE_CHECKING:
- from collections.abc import AsyncIterator, Callable, Iterator, Sequence
- from langchain_core.runnables.schema import StreamEvent
- from langchain_core.tools import BaseTool
- from langchain_core.tracers import RunLog, RunLogPatch
- from pydantic import BaseModel
- @overload
- def init_chat_model(
- model: str,
- *,
- model_provider: str | None = None,
- configurable_fields: None = None,
- config_prefix: str | None = None,
- **kwargs: Any,
- ) -> BaseChatModel: ...
- @overload
- def init_chat_model(
- model: None = None,
- *,
- model_provider: str | None = None,
- configurable_fields: None = None,
- config_prefix: str | None = None,
- **kwargs: Any,
- ) -> _ConfigurableModel: ...
- @overload
- def init_chat_model(
- model: str | None = None,
- *,
- model_provider: str | None = None,
- configurable_fields: Literal["any"] | list[str] | tuple[str, ...] = ...,
- config_prefix: str | None = None,
- **kwargs: Any,
- ) -> _ConfigurableModel: ...
- # FOR CONTRIBUTORS: If adding support for a new provider, please append the provider
- # name to the supported list in the docstring below. Do *not* change the order of the
- # existing providers.
- def init_chat_model(
- model: str | None = None,
- *,
- model_provider: str | None = None,
- configurable_fields: Literal["any"] | list[str] | tuple[str, ...] | None = None,
- config_prefix: str | None = None,
- **kwargs: Any,
- ) -> BaseChatModel | _ConfigurableModel:
- """Initialize a chat model from any supported provider using a unified interface.
- **Two main use cases:**
- 1. **Fixed model** – specify the model upfront and get a ready-to-use chat model.
- 2. **Configurable model** – specify parameters (including the model name) at
- runtime via `config`. This makes it easy to switch between models/providers without
- changing your code.
- !!! note
- Requires the integration package for the chosen model provider to be installed.
- See the `model_provider` parameter below for specific package names
- (e.g., `pip install langchain-openai`).
- Refer to the [provider integration's API reference](https://docs.langchain.com/oss/python/integrations/providers)
- for supported model parameters to use as `**kwargs`.
- Args:
- model: The name or ID of the model, e.g. `'o3-mini'`, `'claude-sonnet-4-5-20250929'`.
- You can also specify model and model provider in a single argument using
- `'{model_provider}:{model}'` format, e.g. `'openai:o1'`.
- If `model_provider` is not specified, it will be inferred from the model name when possible.
- The following providers will be inferred based on these model prefixes:
- - `gpt-...` | `o1...` | `o3...` -> `openai`
- - `claude...` -> `anthropic`
- - `amazon...` -> `bedrock`
- - `gemini...` -> `google_vertexai`
- - `command...` -> `cohere`
- - `accounts/fireworks...` -> `fireworks`
- - `mistral...` -> `mistralai`
- - `deepseek...` -> `deepseek`
- - `grok...` -> `xai`
- - `sonar...` -> `perplexity`
- - `solar...` -> `upstage`
- model_provider: The model provider if not specified as part of the model arg
- (see above).
- Supported `model_provider` values and the corresponding integration package
- are:
- - `openai` -> [`langchain-openai`](https://docs.langchain.com/oss/python/integrations/providers/openai)
- - `anthropic` -> [`langchain-anthropic`](https://docs.langchain.com/oss/python/integrations/providers/anthropic)
- - `azure_openai` -> [`langchain-openai`](https://docs.langchain.com/oss/python/integrations/providers/openai)
- - `azure_ai` -> [`langchain-azure-ai`](https://docs.langchain.com/oss/python/integrations/providers/microsoft)
- - `google_vertexai` -> [`langchain-google-vertexai`](https://docs.langchain.com/oss/python/integrations/providers/google)
- - `google_genai` -> [`langchain-google-genai`](https://docs.langchain.com/oss/python/integrations/providers/google)
- - `bedrock` -> [`langchain-aws`](https://docs.langchain.com/oss/python/integrations/providers/aws)
- - `bedrock_converse` -> [`langchain-aws`](https://docs.langchain.com/oss/python/integrations/providers/aws)
- - `cohere` -> [`langchain-cohere`](https://docs.langchain.com/oss/python/integrations/providers/cohere)
- - `fireworks` -> [`langchain-fireworks`](https://docs.langchain.com/oss/python/integrations/providers/fireworks)
- - `together` -> [`langchain-together`](https://docs.langchain.com/oss/python/integrations/providers/together)
- - `mistralai` -> [`langchain-mistralai`](https://docs.langchain.com/oss/python/integrations/providers/mistralai)
- - `huggingface` -> [`langchain-huggingface`](https://docs.langchain.com/oss/python/integrations/providers/huggingface)
- - `groq` -> [`langchain-groq`](https://docs.langchain.com/oss/python/integrations/providers/groq)
- - `ollama` -> [`langchain-ollama`](https://docs.langchain.com/oss/python/integrations/providers/ollama)
- - `google_anthropic_vertex` -> [`langchain-google-vertexai`](https://docs.langchain.com/oss/python/integrations/providers/google)
- - `deepseek` -> [`langchain-deepseek`](https://docs.langchain.com/oss/python/integrations/providers/deepseek)
- - `ibm` -> [`langchain-ibm`](https://docs.langchain.com/oss/python/integrations/providers/ibm)
- - `nvidia` -> [`langchain-nvidia-ai-endpoints`](https://docs.langchain.com/oss/python/integrations/providers/nvidia)
- - `xai` -> [`langchain-xai`](https://docs.langchain.com/oss/python/integrations/providers/xai)
- - `perplexity` -> [`langchain-perplexity`](https://docs.langchain.com/oss/python/integrations/providers/perplexity)
- - `upstage` -> [`langchain-upstage`](https://docs.langchain.com/oss/python/integrations/providers/upstage)
- configurable_fields: Which model parameters are configurable at runtime:
- - `None`: No configurable fields (i.e., a fixed model).
- - `'any'`: All fields are configurable. **See security note below.**
- `list[str] | tuple[str, ...]`: Only the specified fields are configurable.
- Field names are assumed to have the `config_prefix` stripped if a `config_prefix` is
- specified (see the "Configurable fields with a config prefix" example below).
- If `model` is specified, then defaults to `None`.
- If `model` is not specified, then defaults to `("model", "model_provider")`.
- !!! warning "Security note"
- Setting `configurable_fields="any"` means fields like `api_key`,
- `base_url`, etc., can be altered at runtime, potentially redirecting
- model requests to a different service/user.
- If you're accepting untrusted configurations, make sure to enumerate
- `configurable_fields=(...)` explicitly.
- config_prefix: Optional prefix for configuration keys.
- Useful when you have multiple configurable models in the same application.
- If `config_prefix` is a non-empty string then the model will be configurable
- at runtime via the `config["configurable"]["{config_prefix}_{param}"]` keys.
- See examples below.
- If `config_prefix` is an empty string then the model will be configurable via
- `config["configurable"]["{param}"]`.
- **kwargs: Additional model-specific keyword args to pass to the underlying
- chat model's `__init__` method. Common parameters include:
- - `temperature`: Model temperature for controlling randomness.
- - `max_tokens`: Maximum number of output tokens.
- - `timeout`: Maximum time (in seconds) to wait for a response.
- - `max_retries`: Maximum number of retry attempts for failed requests.
- - `base_url`: Custom API endpoint URL.
- - `rate_limiter`: A
- [`BaseRateLimiter`][langchain_core.rate_limiters.BaseRateLimiter]
- instance to control request rate.
- Refer to the specific model provider's
- [integration reference](https://reference.langchain.com/python/integrations/)
- for all available parameters.
- Returns:
- A `BaseChatModel` corresponding to the specified `model` and `model_provider`
- if no fields are configurable. Otherwise, a chat model emulator that
- initializes the underlying model at runtime once a config is passed in.
- Raises:
- ValueError: If `model_provider` cannot be inferred or isn't supported.
- ImportError: If the model provider integration package is not installed.
- ???+ example "Initialize a non-configurable model"
- ```python
- # pip install langchain langchain-openai langchain-anthropic langchain-google-vertexai
- from langchain.chat_models import init_chat_model
- o3_mini = init_chat_model("openai:o3-mini", temperature=0)
- claude_sonnet = init_chat_model("anthropic:claude-sonnet-4-5-20250929", temperature=0)
- gemini_2_5_flash = init_chat_model("google_vertexai:gemini-2.5-flash", temperature=0)
- o3_mini.invoke("what's your name")
- claude_sonnet.invoke("what's your name")
- gemini_2_5_flash.invoke("what's your name")
- ```
- ??? example "Partially configurable model with no default"
- ```python
- # pip install langchain langchain-openai langchain-anthropic
- from langchain.chat_models import init_chat_model
- # (We don't need to set configurable_fields if no model is specified.)
- configurable_model = init_chat_model(temperature=0)
- configurable_model.invoke("what's your name", config={"configurable": {"model": "gpt-4o"}})
- # Use GPT-4o to generate the response
- configurable_model.invoke(
- "what's your name",
- config={"configurable": {"model": "claude-sonnet-4-5-20250929"}},
- )
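- # Use Claude Sonnet 4.5 to generate the response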
- ```
- ??? example "Fully configurable model with a default"
- ```python
- # pip install langchain langchain-openai langchain-anthropic
- from langchain.chat_models import init_chat_model
- configurable_model_with_default = init_chat_model(
- "openai:gpt-4o",
- configurable_fields="any", # This allows us to configure other params like temperature, max_tokens, etc at runtime.
- config_prefix="foo",
- temperature=0,
- )
- configurable_model_with_default.invoke("what's your name")
- # GPT-4o response with temperature 0 (as set in default)
- configurable_model_with_default.invoke(
- "what's your name",
- config={
- "configurable": {
- "foo_model": "anthropic:claude-sonnet-4-5-20250929",
- "foo_temperature": 0.6,
- }
- },
- )
- # Override default to use Sonnet 4.5 with temperature 0.6 to generate response
- ```
- ??? example "Bind tools to a configurable model"
- You can call any declarative chat model method (e.g., `bind_tools`,
- `with_structured_output`) on a configurable model in the same way you would
- with a regular model:
- ```python
- # pip install langchain langchain-openai langchain-anthropic
- from langchain.chat_models import init_chat_model
- from pydantic import BaseModel, Field
- class GetWeather(BaseModel):
- '''Get the current weather in a given location'''
- location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
- class GetPopulation(BaseModel):
- '''Get the current population in a given location'''
- location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
- configurable_model = init_chat_model(
- "gpt-4o", configurable_fields=("model", "model_provider"), temperature=0
- )
- configurable_model_with_tools = configurable_model.bind_tools(
- [
- GetWeather,
- GetPopulation,
- ]
- )
- configurable_model_with_tools.invoke(
- "Which city is hotter today and which is bigger: LA or NY?"
- )
- # Use GPT-4o
- configurable_model_with_tools.invoke(
- "Which city is hotter today and which is bigger: LA or NY?",
- config={"configurable": {"model": "claude-sonnet-4-5-20250929"}},
- )
- # Use Sonnet 4.5
- ```
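- ??? example "Configurable fields with a config prefix"
- When both `configurable_fields` and `config_prefix` are set, runtime keys carry the
- prefix, which is stripped before the values are passed to the model constructor.
- A minimal sketch (model names are illustrative; assumes the OpenAI and Anthropic
- integrations are installed):
- ```python
- # pip install langchain langchain-openai langchain-anthropic
- from langchain.chat_models import init_chat_model
- model = init_chat_model(
- "openai:gpt-4o",
- configurable_fields=("model", "model_provider", "temperature"),
- config_prefix="chat",
- )
- model.invoke("what's your name")
- # Uses the default GPT-4o
- model.invoke(
- "what's your name",
- config={
- "configurable": {
- "chat_model": "claude-sonnet-4-5-20250929",
- "chat_model_provider": "anthropic",
- "chat_temperature": 0.2,
- }
- },
- )
- # The 'chat_' prefix is stripped, so the Anthropic model is used with temperature 0.2
- ```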
- """ # noqa: E501
- if not model and not configurable_fields:
- configurable_fields = ("model", "model_provider")
- config_prefix = config_prefix or ""
- if config_prefix and not configurable_fields:
- warnings.warn(
- f"{config_prefix=} has been set but no fields are configurable. Set "
- f"`configurable_fields=(...)` to specify the model params that are "
- f"configurable.",
- stacklevel=2,
- )
- if not configurable_fields:
- return _init_chat_model_helper(
- cast("str", model),
- model_provider=model_provider,
- **kwargs,
- )
- if model:
- kwargs["model"] = model
- if model_provider:
- kwargs["model_provider"] = model_provider
- return _ConfigurableModel(
- default_config=kwargs,
- config_prefix=config_prefix,
- configurable_fields=configurable_fields,
- )
- def _init_chat_model_helper(
- model: str,
- *,
- model_provider: str | None = None,
- **kwargs: Any,
- ) -> BaseChatModel:
- model, model_provider = _parse_model(model, model_provider)
- if model_provider == "openai":
- _check_pkg("langchain_openai")
- from langchain_openai import ChatOpenAI
- return ChatOpenAI(model=model, **kwargs)
- if model_provider == "anthropic":
- _check_pkg("langchain_anthropic")
- from langchain_anthropic import ChatAnthropic
- return ChatAnthropic(model=model, **kwargs) # type: ignore[call-arg,unused-ignore]
- if model_provider == "azure_openai":
- _check_pkg("langchain_openai")
- from langchain_openai import AzureChatOpenAI
- return AzureChatOpenAI(model=model, **kwargs)
- if model_provider == "azure_ai":
- _check_pkg("langchain_azure_ai")
- from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel
- return AzureAIChatCompletionsModel(model=model, **kwargs)
- if model_provider == "cohere":
- _check_pkg("langchain_cohere")
- from langchain_cohere import ChatCohere
- return ChatCohere(model=model, **kwargs)
- if model_provider == "google_vertexai":
- _check_pkg("langchain_google_vertexai")
- from langchain_google_vertexai import ChatVertexAI
- return ChatVertexAI(model=model, **kwargs)
- if model_provider == "google_genai":
- _check_pkg("langchain_google_genai")
- from langchain_google_genai import ChatGoogleGenerativeAI
- return ChatGoogleGenerativeAI(model=model, **kwargs)
- if model_provider == "fireworks":
- _check_pkg("langchain_fireworks")
- from langchain_fireworks import ChatFireworks
- return ChatFireworks(model=model, **kwargs)
- if model_provider == "ollama":
- try:
- _check_pkg("langchain_ollama")
- from langchain_ollama import ChatOllama
- except ImportError:
- # For backwards compatibility
- try:
- _check_pkg("langchain_community")
- from langchain_community.chat_models import ChatOllama
- except ImportError:
- # If neither langchain-ollama nor langchain-community is available,
- # raise an error related to langchain-ollama
- _check_pkg("langchain_ollama")
- return ChatOllama(model=model, **kwargs)
- if model_provider == "together":
- _check_pkg("langchain_together")
- from langchain_together import ChatTogether
- return ChatTogether(model=model, **kwargs)
- if model_provider == "mistralai":
- _check_pkg("langchain_mistralai")
- from langchain_mistralai import ChatMistralAI
- return ChatMistralAI(model=model, **kwargs) # type: ignore[call-arg,unused-ignore]
- if model_provider == "huggingface":
- _check_pkg("langchain_huggingface")
- from langchain_huggingface import ChatHuggingFace
- return ChatHuggingFace(model_id=model, **kwargs)
- if model_provider == "groq":
- _check_pkg("langchain_groq")
- from langchain_groq import ChatGroq
- return ChatGroq(model=model, **kwargs)
- if model_provider == "bedrock":
- _check_pkg("langchain_aws")
- from langchain_aws import ChatBedrock
- return ChatBedrock(model_id=model, **kwargs)
- if model_provider == "bedrock_converse":
- _check_pkg("langchain_aws")
- from langchain_aws import ChatBedrockConverse
- return ChatBedrockConverse(model=model, **kwargs)
- if model_provider == "google_anthropic_vertex":
- _check_pkg("langchain_google_vertexai")
- from langchain_google_vertexai.model_garden import ChatAnthropicVertex
- return ChatAnthropicVertex(model=model, **kwargs)
- if model_provider == "deepseek":
- _check_pkg("langchain_deepseek", pkg_kebab="langchain-deepseek")
- from langchain_deepseek import ChatDeepSeek
- return ChatDeepSeek(model=model, **kwargs)
- if model_provider == "nvidia":
- _check_pkg("langchain_nvidia_ai_endpoints")
- from langchain_nvidia_ai_endpoints import ChatNVIDIA
- return ChatNVIDIA(model=model, **kwargs)
- if model_provider == "ibm":
- _check_pkg("langchain_ibm")
- from langchain_ibm import ChatWatsonx
- return ChatWatsonx(model_id=model, **kwargs)
- if model_provider == "xai":
- _check_pkg("langchain_xai")
- from langchain_xai import ChatXAI
- return ChatXAI(model=model, **kwargs)
- if model_provider == "perplexity":
- _check_pkg("langchain_perplexity")
- from langchain_perplexity import ChatPerplexity
- return ChatPerplexity(model=model, **kwargs)
- if model_provider == "upstage":
- _check_pkg("langchain_upstage")
- from langchain_upstage import ChatUpstage
- return ChatUpstage(model=model, **kwargs)
- supported = ", ".join(_SUPPORTED_PROVIDERS)
- msg = f"Unsupported {model_provider=}.\n\nSupported model providers are: {supported}"
- raise ValueError(msg)
- _SUPPORTED_PROVIDERS = {
- "openai",
- "anthropic",
- "azure_openai",
- "azure_ai",
- "cohere",
- "google_vertexai",
- "google_genai",
- "fireworks",
- "ollama",
- "together",
- "mistralai",
- "huggingface",
- "groq",
- "bedrock",
- "bedrock_converse",
- "google_anthropic_vertex",
- "deepseek",
- "ibm",
- "xai",
- "perplexity",
- "upstage",
- }
- def _attempt_infer_model_provider(model_name: str) -> str | None:
- if any(model_name.startswith(pre) for pre in ("gpt-", "o1", "o3")):
- return "openai"
- if model_name.startswith("claude"):
- return "anthropic"
- if model_name.startswith("command"):
- return "cohere"
- if model_name.startswith("accounts/fireworks"):
- return "fireworks"
- if model_name.startswith("gemini"):
- return "google_vertexai"
- if model_name.startswith("amazon."):
- return "bedrock"
- if model_name.startswith("mistral"):
- return "mistralai"
- if model_name.startswith("deepseek"):
- return "deepseek"
- if model_name.startswith("grok"):
- return "xai"
- if model_name.startswith("sonar"):
- return "perplexity"
- if model_name.startswith("solar"):
- return "upstage"
- return None
- def _parse_model(model: str, model_provider: str | None) -> tuple[str, str]:
- if not model_provider and ":" in model and model.split(":")[0] in _SUPPORTED_PROVIDERS:
- model_provider = model.split(":")[0]
- model = ":".join(model.split(":")[1:])
- model_provider = model_provider or _attempt_infer_model_provider(model)
- if not model_provider:
- msg = (
- f"Unable to infer model provider for {model=}, please specify model_provider directly."
- )
- raise ValueError(msg)
- model_provider = model_provider.replace("-", "_").lower()
- return model, model_provider
- def _check_pkg(pkg: str, *, pkg_kebab: str | None = None) -> None:
- if not util.find_spec(pkg):
- pkg_kebab = pkg_kebab if pkg_kebab is not None else pkg.replace("_", "-")
- msg = f"Unable to import {pkg}. Please install with `pip install -U {pkg_kebab}`"
- raise ImportError(msg)
- def _remove_prefix(s: str, prefix: str) -> str:
- return s.removeprefix(prefix)
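- # Methods that return a new runnable rather than a model response; on
- # _ConfigurableModel they are queued and replayed once a concrete model is built.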
- _DECLARATIVE_METHODS = ("bind_tools", "with_structured_output")
- class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
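- """Chat model emulator that defers initializing the underlying model until a
- config is provided at runtime (see `init_chat_model`)."""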
- def __init__(
- self,
- *,
- default_config: dict | None = None,
- configurable_fields: Literal["any"] | list[str] | tuple[str, ...] = "any",
- config_prefix: str = "",
- queued_declarative_operations: Sequence[tuple[str, tuple, dict]] = (),
- ) -> None:
- self._default_config: dict = default_config or {}
- self._configurable_fields: Literal["any"] | list[str] = (
- configurable_fields if configurable_fields == "any" else list(configurable_fields)
- )
- self._config_prefix = (
- config_prefix + "_"
- if config_prefix and not config_prefix.endswith("_")
- else config_prefix
- )
- self._queued_declarative_operations: list[tuple[str, tuple, dict]] = list(
- queued_declarative_operations,
- )
- def __getattr__(self, name: str) -> Any:
- if name in _DECLARATIVE_METHODS:
- # Declarative operations cannot be applied until after an actual model
- # object is instantiated. So instead of returning the actual operation,
- # we record the operation and its arguments in a queue. This queue is
- # then applied in order whenever we actually instantiate the model (in
- # self._model()).
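- # Illustrative (hypothetical) flow, names are for example only:
- #   cm = init_chat_model(configurable_fields=("model", "model_provider"))
- #   cm_tools = cm.bind_tools([SomeTool])  # queues ("bind_tools", ([SomeTool],), {})
- #   cm_tools.invoke("hi", config={"configurable": {"model": "gpt-4o"}})
- #   # -> builds ChatOpenAI(model="gpt-4o"), then applies .bind_tools([SomeTool])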
- def queue(*args: Any, **kwargs: Any) -> _ConfigurableModel:
- queued_declarative_operations = list(
- self._queued_declarative_operations,
- )
- queued_declarative_operations.append((name, args, kwargs))
- return _ConfigurableModel(
- default_config=dict(self._default_config),
- configurable_fields=list(self._configurable_fields)
- if isinstance(self._configurable_fields, list)
- else self._configurable_fields,
- config_prefix=self._config_prefix,
- queued_declarative_operations=queued_declarative_operations,
- )
- return queue
- if self._default_config and (model := self._model()) and hasattr(model, name):
- return getattr(model, name)
- msg = f"{name} is not a BaseChatModel attribute"
- if self._default_config:
- msg += " and is not implemented on the default model"
- msg += "."
- raise AttributeError(msg)
- def _model(self, config: RunnableConfig | None = None) -> Runnable:
- params = {**self._default_config, **self._model_params(config)}
- model = _init_chat_model_helper(**params)
- for name, args, kwargs in self._queued_declarative_operations:
- model = getattr(model, name)(*args, **kwargs)
- return model
- def _model_params(self, config: RunnableConfig | None) -> dict:
- config = ensure_config(config)
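- # Keep only configurable entries that carry this model's prefix and strip the
- # prefix so they can be passed directly as model kwargs.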
- model_params = {
- _remove_prefix(k, self._config_prefix): v
- for k, v in config.get("configurable", {}).items()
- if k.startswith(self._config_prefix)
- }
- if self._configurable_fields != "any":
- model_params = {k: v for k, v in model_params.items() if k in self._configurable_fields}
- return model_params
- def with_config(
- self,
- config: RunnableConfig | None = None,
- **kwargs: Any,
- ) -> _ConfigurableModel:
- """Bind config to a `Runnable`, returning a new `Runnable`."""
- config = RunnableConfig(**(config or {}), **cast("RunnableConfig", kwargs))
- model_params = self._model_params(config)
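- # Model params found in the config are folded into the default config below;
- # everything else (tags, callbacks, unrelated configurable keys) is queued as a
- # `with_config` declarative op, applied once the underlying model is instantiated.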
- remaining_config = {k: v for k, v in config.items() if k != "configurable"}
- remaining_config["configurable"] = {
- k: v
- for k, v in config.get("configurable", {}).items()
- if _remove_prefix(k, self._config_prefix) not in model_params
- }
- queued_declarative_operations = list(self._queued_declarative_operations)
- if remaining_config:
- queued_declarative_operations.append(
- (
- "with_config",
- (),
- {"config": remaining_config},
- ),
- )
- return _ConfigurableModel(
- default_config={**self._default_config, **model_params},
- configurable_fields=list(self._configurable_fields)
- if isinstance(self._configurable_fields, list)
- else self._configurable_fields,
- config_prefix=self._config_prefix,
- queued_declarative_operations=queued_declarative_operations,
- )
- @property
- def InputType(self) -> TypeAlias:
- """Get the input type for this `Runnable`."""
- from langchain_core.prompt_values import ChatPromptValueConcrete, StringPromptValue
- # This is a version of LanguageModelInput which replaces the abstract
- # base class BaseMessage with a union of its subclasses, which makes
- # for a much better schema.
- return str | StringPromptValue | ChatPromptValueConcrete | list[AnyMessage]
- @override
- def invoke(
- self,
- input: LanguageModelInput,
- config: RunnableConfig | None = None,
- **kwargs: Any,
- ) -> Any:
- return self._model(config).invoke(input, config=config, **kwargs)
- @override
- async def ainvoke(
- self,
- input: LanguageModelInput,
- config: RunnableConfig | None = None,
- **kwargs: Any,
- ) -> Any:
- return await self._model(config).ainvoke(input, config=config, **kwargs)
- @override
- def stream(
- self,
- input: LanguageModelInput,
- config: RunnableConfig | None = None,
- **kwargs: Any | None,
- ) -> Iterator[Any]:
- yield from self._model(config).stream(input, config=config, **kwargs)
- @override
- async def astream(
- self,
- input: LanguageModelInput,
- config: RunnableConfig | None = None,
- **kwargs: Any | None,
- ) -> AsyncIterator[Any]:
- async for x in self._model(config).astream(input, config=config, **kwargs):
- yield x
- def batch(
- self,
- inputs: list[LanguageModelInput],
- config: RunnableConfig | list[RunnableConfig] | None = None,
- *,
- return_exceptions: bool = False,
- **kwargs: Any | None,
- ) -> list[Any]:
- config = config or None
- # If there is at most one config, use the underlying model's batch implementation.
- if config is None or isinstance(config, dict) or len(config) <= 1:
- if isinstance(config, list):
- config = config[0]
- return self._model(config).batch(
- inputs,
- config=config,
- return_exceptions=return_exceptions,
- **kwargs,
- )
- # If there are multiple configs, default to Runnable.batch, which uses an
- # executor to invoke in parallel.
- return super().batch(
- inputs,
- config=config,
- return_exceptions=return_exceptions,
- **kwargs,
- )
- async def abatch(
- self,
- inputs: list[LanguageModelInput],
- config: RunnableConfig | list[RunnableConfig] | None = None,
- *,
- return_exceptions: bool = False,
- **kwargs: Any | None,
- ) -> list[Any]:
- config = config or None
- # If there is at most one config, use the underlying model's batch implementation.
- if config is None or isinstance(config, dict) or len(config) <= 1:
- if isinstance(config, list):
- config = config[0]
- return await self._model(config).abatch(
- inputs,
- config=config,
- return_exceptions=return_exceptions,
- **kwargs,
- )
- # If there are multiple configs, default to Runnable.abatch, which uses an
- # executor to invoke in parallel.
- return await super().abatch(
- inputs,
- config=config,
- return_exceptions=return_exceptions,
- **kwargs,
- )
- def batch_as_completed(
- self,
- inputs: Sequence[LanguageModelInput],
- config: RunnableConfig | Sequence[RunnableConfig] | None = None,
- *,
- return_exceptions: bool = False,
- **kwargs: Any,
- ) -> Iterator[tuple[int, Any | Exception]]:
- config = config or None
- # If there is at most one config, use the underlying model's batch implementation.
- if config is None or isinstance(config, dict) or len(config) <= 1:
- if isinstance(config, list):
- config = config[0]
- yield from self._model(cast("RunnableConfig", config)).batch_as_completed( # type: ignore[call-overload]
- inputs,
- config=config,
- return_exceptions=return_exceptions,
- **kwargs,
- )
- # If there are multiple configs, default to Runnable.batch_as_completed, which
- # uses an executor to invoke in parallel.
- else:
- yield from super().batch_as_completed( # type: ignore[call-overload]
- inputs,
- config=config,
- return_exceptions=return_exceptions,
- **kwargs,
- )
- async def abatch_as_completed(
- self,
- inputs: Sequence[LanguageModelInput],
- config: RunnableConfig | Sequence[RunnableConfig] | None = None,
- *,
- return_exceptions: bool = False,
- **kwargs: Any,
- ) -> AsyncIterator[tuple[int, Any]]:
- config = config or None
- # If there is at most one config, use the underlying model's batch implementation.
- if config is None or isinstance(config, dict) or len(config) <= 1:
- if isinstance(config, list):
- config = config[0]
- async for x in self._model(
- cast("RunnableConfig", config),
- ).abatch_as_completed( # type: ignore[call-overload]
- inputs,
- config=config,
- return_exceptions=return_exceptions,
- **kwargs,
- ):
- yield x
- # If there are multiple configs, default to Runnable.abatch_as_completed, which
- # uses an executor to invoke in parallel.
- else:
- async for x in super().abatch_as_completed( # type: ignore[call-overload]
- inputs,
- config=config,
- return_exceptions=return_exceptions,
- **kwargs,
- ):
- yield x
- @override
- def transform(
- self,
- input: Iterator[LanguageModelInput],
- config: RunnableConfig | None = None,
- **kwargs: Any | None,
- ) -> Iterator[Any]:
- yield from self._model(config).transform(input, config=config, **kwargs)
- @override
- async def atransform(
- self,
- input: AsyncIterator[LanguageModelInput],
- config: RunnableConfig | None = None,
- **kwargs: Any | None,
- ) -> AsyncIterator[Any]:
- async for x in self._model(config).atransform(input, config=config, **kwargs):
- yield x
- @overload
- def astream_log(
- self,
- input: Any,
- config: RunnableConfig | None = None,
- *,
- diff: Literal[True] = True,
- with_streamed_output_list: bool = True,
- include_names: Sequence[str] | None = None,
- include_types: Sequence[str] | None = None,
- include_tags: Sequence[str] | None = None,
- exclude_names: Sequence[str] | None = None,
- exclude_types: Sequence[str] | None = None,
- exclude_tags: Sequence[str] | None = None,
- **kwargs: Any,
- ) -> AsyncIterator[RunLogPatch]: ...
- @overload
- def astream_log(
- self,
- input: Any,
- config: RunnableConfig | None = None,
- *,
- diff: Literal[False],
- with_streamed_output_list: bool = True,
- include_names: Sequence[str] | None = None,
- include_types: Sequence[str] | None = None,
- include_tags: Sequence[str] | None = None,
- exclude_names: Sequence[str] | None = None,
- exclude_types: Sequence[str] | None = None,
- exclude_tags: Sequence[str] | None = None,
- **kwargs: Any,
- ) -> AsyncIterator[RunLog]: ...
- @override
- async def astream_log(
- self,
- input: Any,
- config: RunnableConfig | None = None,
- *,
- diff: bool = True,
- with_streamed_output_list: bool = True,
- include_names: Sequence[str] | None = None,
- include_types: Sequence[str] | None = None,
- include_tags: Sequence[str] | None = None,
- exclude_names: Sequence[str] | None = None,
- exclude_types: Sequence[str] | None = None,
- exclude_tags: Sequence[str] | None = None,
- **kwargs: Any,
- ) -> AsyncIterator[RunLogPatch] | AsyncIterator[RunLog]:
- async for x in self._model(config).astream_log( # type: ignore[call-overload, misc]
- input,
- config=config,
- diff=diff,
- with_streamed_output_list=with_streamed_output_list,
- include_names=include_names,
- include_types=include_types,
- include_tags=include_tags,
- exclude_tags=exclude_tags,
- exclude_types=exclude_types,
- exclude_names=exclude_names,
- **kwargs,
- ):
- yield x
- @override
- async def astream_events(
- self,
- input: Any,
- config: RunnableConfig | None = None,
- *,
- version: Literal["v1", "v2"] = "v2",
- include_names: Sequence[str] | None = None,
- include_types: Sequence[str] | None = None,
- include_tags: Sequence[str] | None = None,
- exclude_names: Sequence[str] | None = None,
- exclude_types: Sequence[str] | None = None,
- exclude_tags: Sequence[str] | None = None,
- **kwargs: Any,
- ) -> AsyncIterator[StreamEvent]:
- async for x in self._model(config).astream_events(
- input,
- config=config,
- version=version,
- include_names=include_names,
- include_types=include_types,
- include_tags=include_tags,
- exclude_tags=exclude_tags,
- exclude_types=exclude_types,
- exclude_names=exclude_names,
- **kwargs,
- ):
- yield x
- # Explicitly added to satisfy downstream linters.
- def bind_tools(
- self,
- tools: Sequence[dict[str, Any] | type[BaseModel] | Callable | BaseTool],
- **kwargs: Any,
- ) -> Runnable[LanguageModelInput, AIMessage]:
- return self.__getattr__("bind_tools")(tools, **kwargs)
- # Explicitly added to satisfy downstream linters.
- def with_structured_output(
- self,
- schema: dict | type[BaseModel],
- **kwargs: Any,
- ) -> Runnable[LanguageModelInput, dict | BaseModel]:
- return self.__getattr__("with_structured_output")(schema, **kwargs)
|