  1. """Factory functions for chat models."""
  2. from __future__ import annotations
  3. import warnings
  4. from importlib import util
  5. from typing import TYPE_CHECKING, Any, Literal, TypeAlias, cast, overload
  6. from langchain_core.language_models import BaseChatModel, LanguageModelInput
  7. from langchain_core.messages import AIMessage, AnyMessage
  8. from langchain_core.runnables import Runnable, RunnableConfig, ensure_config
  9. from typing_extensions import override
  10. if TYPE_CHECKING:
  11. from collections.abc import AsyncIterator, Callable, Iterator, Sequence
  12. from langchain_core.runnables.schema import StreamEvent
  13. from langchain_core.tools import BaseTool
  14. from langchain_core.tracers import RunLog, RunLogPatch
  15. from pydantic import BaseModel


@overload
def init_chat_model(
    model: str,
    *,
    model_provider: str | None = None,
    configurable_fields: None = None,
    config_prefix: str | None = None,
    **kwargs: Any,
) -> BaseChatModel: ...


@overload
def init_chat_model(
    model: None = None,
    *,
    model_provider: str | None = None,
    configurable_fields: None = None,
    config_prefix: str | None = None,
    **kwargs: Any,
) -> _ConfigurableModel: ...


@overload
def init_chat_model(
    model: str | None = None,
    *,
    model_provider: str | None = None,
    configurable_fields: Literal["any"] | list[str] | tuple[str, ...] = ...,
    config_prefix: str | None = None,
    **kwargs: Any,
) -> _ConfigurableModel: ...


# FOR CONTRIBUTORS: If adding support for a new provider, please append the provider
# name to the supported list in the docstring below. Do *not* change the order of the
# existing providers.
def init_chat_model(
    model: str | None = None,
    *,
    model_provider: str | None = None,
    configurable_fields: Literal["any"] | list[str] | tuple[str, ...] | None = None,
    config_prefix: str | None = None,
    **kwargs: Any,
) -> BaseChatModel | _ConfigurableModel:
  54. """Initialize a chat model from any supported provider using a unified interface.
  55. **Two main use cases:**
  56. 1. **Fixed model** – specify the model upfront and get a ready-to-use chat model.
  57. 2. **Configurable model** – choose to specify parameters (including model name) at
  58. runtime via `config`. Makes it easy to switch between models/providers without
  59. changing your code
  60. !!! note
  61. Requires the integration package for the chosen model provider to be installed.
  62. See the `model_provider` parameter below for specific package names
  63. (e.g., `pip install langchain-openai`).
  64. Refer to the [provider integration's API reference](https://docs.langchain.com/oss/python/integrations/providers)
  65. for supported model parameters to use as `**kwargs`.
    Args:
        model: The name or ID of the model, e.g. `'o3-mini'`,
            `'claude-sonnet-4-5-20250929'`.

            You can also specify model and model provider in a single argument using
            `'{model_provider}:{model}'` format, e.g. `'openai:o1'`.

            Will attempt to infer `model_provider` from `model` if not specified.
            The following providers will be inferred based on these model prefixes:

            - `gpt-...` | `o1...` | `o3...` -> `openai`
            - `claude...` -> `anthropic`
            - `amazon...` -> `bedrock`
            - `gemini...` -> `google_vertexai`
            - `command...` -> `cohere`
            - `accounts/fireworks...` -> `fireworks`
            - `mistral...` -> `mistralai`
            - `deepseek...` -> `deepseek`
            - `grok...` -> `xai`
            - `sonar...` -> `perplexity`
            - `solar...` -> `upstage`
        model_provider: The model provider if not specified as part of the model arg
            (see above).

            Supported `model_provider` values and the corresponding integration
            package are:

            - `openai` -> [`langchain-openai`](https://docs.langchain.com/oss/python/integrations/providers/openai)
            - `anthropic` -> [`langchain-anthropic`](https://docs.langchain.com/oss/python/integrations/providers/anthropic)
            - `azure_openai` -> [`langchain-openai`](https://docs.langchain.com/oss/python/integrations/providers/openai)
            - `azure_ai` -> [`langchain-azure-ai`](https://docs.langchain.com/oss/python/integrations/providers/microsoft)
            - `google_vertexai` -> [`langchain-google-vertexai`](https://docs.langchain.com/oss/python/integrations/providers/google)
            - `google_genai` -> [`langchain-google-genai`](https://docs.langchain.com/oss/python/integrations/providers/google)
            - `bedrock` -> [`langchain-aws`](https://docs.langchain.com/oss/python/integrations/providers/aws)
            - `bedrock_converse` -> [`langchain-aws`](https://docs.langchain.com/oss/python/integrations/providers/aws)
            - `cohere` -> [`langchain-cohere`](https://docs.langchain.com/oss/python/integrations/providers/cohere)
            - `fireworks` -> [`langchain-fireworks`](https://docs.langchain.com/oss/python/integrations/providers/fireworks)
            - `together` -> [`langchain-together`](https://docs.langchain.com/oss/python/integrations/providers/together)
            - `mistralai` -> [`langchain-mistralai`](https://docs.langchain.com/oss/python/integrations/providers/mistralai)
            - `huggingface` -> [`langchain-huggingface`](https://docs.langchain.com/oss/python/integrations/providers/huggingface)
            - `groq` -> [`langchain-groq`](https://docs.langchain.com/oss/python/integrations/providers/groq)
            - `ollama` -> [`langchain-ollama`](https://docs.langchain.com/oss/python/integrations/providers/ollama)
            - `google_anthropic_vertex` -> [`langchain-google-vertexai`](https://docs.langchain.com/oss/python/integrations/providers/google)
            - `deepseek` -> [`langchain-deepseek`](https://docs.langchain.com/oss/python/integrations/providers/deepseek)
            - `ibm` -> [`langchain-ibm`](https://docs.langchain.com/oss/python/integrations/providers/ibm)
            - `nvidia` -> [`langchain-nvidia-ai-endpoints`](https://docs.langchain.com/oss/python/integrations/providers/nvidia)
            - `xai` -> [`langchain-xai`](https://docs.langchain.com/oss/python/integrations/providers/xai)
            - `perplexity` -> [`langchain-perplexity`](https://docs.langchain.com/oss/python/integrations/providers/perplexity)
            - `upstage` -> [`langchain-upstage`](https://docs.langchain.com/oss/python/integrations/providers/upstage)
        configurable_fields: Which model parameters are configurable at runtime:

            - `None`: No configurable fields (i.e., a fixed model).
            - `'any'`: All fields are configurable. **See security note below.**
            - `list[str] | tuple[str, ...]`: Specified fields are configurable.

            Fields are assumed to have `config_prefix` stripped if a `config_prefix`
            is specified.

            If `model` is specified, then defaults to `None`.

            If `model` is not specified, then defaults to
            `("model", "model_provider")`.

            !!! warning "Security note"
                Setting `configurable_fields="any"` means fields like `api_key`,
                `base_url`, etc., can be altered at runtime, potentially redirecting
                model requests to a different service/user.

                If you're accepting untrusted configurations, make sure to
                explicitly enumerate `configurable_fields=(...)`.
        config_prefix: Optional prefix for configuration keys.

            Useful when you have multiple configurable models in the same
            application.

            If `config_prefix` is a non-empty string then `model` will be
            configurable at runtime via the
            `config["configurable"]["{config_prefix}_{param}"]` keys. See examples
            below.

            If `config_prefix` is an empty string then model will be configurable
            via `config["configurable"]["{param}"]`.
        **kwargs: Additional model-specific keyword args to pass to the underlying
            chat model's `__init__` method. Common parameters include:

            - `temperature`: Model temperature for controlling randomness.
            - `max_tokens`: Maximum number of output tokens.
            - `timeout`: Maximum time (in seconds) to wait for a response.
            - `max_retries`: Maximum number of retry attempts for failed requests.
            - `base_url`: Custom API endpoint URL.
            - `rate_limiter`: A
                [`BaseRateLimiter`][langchain_core.rate_limiters.BaseRateLimiter]
                instance to control request rate.

            Refer to the specific model provider's
            [integration reference](https://reference.langchain.com/python/integrations/)
            for all available parameters.

    Returns:
        A `BaseChatModel` corresponding to the `model` and `model_provider`
        specified if configurability is inferred to be `False`. If configurable, a
        chat model emulator that initializes the underlying model at runtime once a
        config is passed in.

    Raises:
        ValueError: If `model_provider` cannot be inferred or isn't supported.
        ImportError: If the model provider integration package is not installed.
    ???+ example "Initialize a non-configurable model"
        ```python
        # pip install langchain langchain-openai langchain-anthropic langchain-google-vertexai
        from langchain.chat_models import init_chat_model

        o3_mini = init_chat_model("openai:o3-mini", temperature=0)
        claude_sonnet = init_chat_model("anthropic:claude-sonnet-4-5-20250929", temperature=0)
        gemini_2_5_flash = init_chat_model("google_vertexai:gemini-2.5-flash", temperature=0)

        o3_mini.invoke("what's your name")
        claude_sonnet.invoke("what's your name")
        gemini_2_5_flash.invoke("what's your name")
        ```
    ??? example "Partially configurable model with no default"
        ```python
        # pip install langchain langchain-openai langchain-anthropic
        from langchain.chat_models import init_chat_model

        # (We don't need to specify configurable=True if a model isn't specified.)
        configurable_model = init_chat_model(temperature=0)

        configurable_model.invoke(
            "what's your name", config={"configurable": {"model": "gpt-4o"}}
        )
        # Use GPT-4o to generate the response

        configurable_model.invoke(
            "what's your name",
            config={"configurable": {"model": "claude-sonnet-4-5-20250929"}},
        )
        ```

    ??? example "Fully configurable model with a default"
        ```python
        # pip install langchain langchain-openai langchain-anthropic
        from langchain.chat_models import init_chat_model

        configurable_model_with_default = init_chat_model(
            "openai:gpt-4o",
            # Allows configuring other params (temperature, max_tokens, etc.) at runtime
            configurable_fields="any",
            config_prefix="foo",
            temperature=0,
        )

        configurable_model_with_default.invoke("what's your name")
        # GPT-4o response with temperature 0 (as set in default)

        configurable_model_with_default.invoke(
            "what's your name",
            config={
                "configurable": {
                    "foo_model": "anthropic:claude-sonnet-4-5-20250929",
                    "foo_temperature": 0.6,
                }
            },
        )
        # Override default to use Sonnet 4.5 with temperature 0.6 to generate response
        ```
    ??? example "Bind tools to a configurable model"
        You can call any declarative chat model method on a configurable model in
        the same way that you would with a normal model:

        ```python
        # pip install langchain langchain-openai langchain-anthropic
        from langchain.chat_models import init_chat_model
        from pydantic import BaseModel, Field


        class GetWeather(BaseModel):
            '''Get the current weather in a given location'''

            location: str = Field(..., description="The city and state, e.g. San Francisco, CA")


        class GetPopulation(BaseModel):
            '''Get the current population in a given location'''

            location: str = Field(..., description="The city and state, e.g. San Francisco, CA")


        configurable_model = init_chat_model(
            "gpt-4o", configurable_fields=("model", "model_provider"), temperature=0
        )

        configurable_model_with_tools = configurable_model.bind_tools(
            [
                GetWeather,
                GetPopulation,
            ]
        )
        configurable_model_with_tools.invoke(
            "Which city is hotter today and which is bigger: LA or NY?"
        )
        # Use GPT-4o

        configurable_model_with_tools.invoke(
            "Which city is hotter today and which is bigger: LA or NY?",
            config={"configurable": {"model": "claude-sonnet-4-5-20250929"}},
        )
        # Use Sonnet 4.5
        ```
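
    ??? example "Restrict which fields are configurable"
        Per the security note above, explicitly enumerating `configurable_fields`
        keeps sensitive params like `api_key` or `base_url` fixed. A minimal
        sketch using only the parameters documented above:

        ```python
        # pip install langchain langchain-openai
        from langchain.chat_models import init_chat_model

        restricted_model = init_chat_model(
            "openai:gpt-4o",
            configurable_fields=("temperature",),  # Only temperature may vary
            config_prefix="llm",
            temperature=0,
        )
        restricted_model.invoke(
            "what's your name",
            config={"configurable": {"llm_temperature": 0.9}},
        )
        # Only the enumerated field is overridden; runtime overrides of any other
        # field (e.g. {"llm_base_url": ...}) are ignored.
        ```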
  230. """ # noqa: E501
    if not model and not configurable_fields:
        configurable_fields = ("model", "model_provider")
    config_prefix = config_prefix or ""
    if config_prefix and not configurable_fields:
        warnings.warn(
            f"{config_prefix=} has been set but no fields are configurable. Set "
            f"`configurable_fields=(...)` to specify the model params that are "
            f"configurable.",
            stacklevel=2,
        )
    if not configurable_fields:
        return _init_chat_model_helper(
            cast("str", model),
            model_provider=model_provider,
            **kwargs,
        )
    if model:
        kwargs["model"] = model
    if model_provider:
        kwargs["model_provider"] = model_provider
    return _ConfigurableModel(
        default_config=kwargs,
        config_prefix=config_prefix,
        configurable_fields=configurable_fields,
    )


def _init_chat_model_helper(
    model: str,
    *,
    model_provider: str | None = None,
    **kwargs: Any,
) -> BaseChatModel:
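    """Resolve `model`/`model_provider` and instantiate the matching chat model.

    Imports the provider's integration package lazily, raising `ImportError` if it
    is not installed.
    """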
    model, model_provider = _parse_model(model, model_provider)
    if model_provider == "openai":
        _check_pkg("langchain_openai")
        from langchain_openai import ChatOpenAI

        return ChatOpenAI(model=model, **kwargs)
    if model_provider == "anthropic":
        _check_pkg("langchain_anthropic")
        from langchain_anthropic import ChatAnthropic

        return ChatAnthropic(model=model, **kwargs)  # type: ignore[call-arg,unused-ignore]
    if model_provider == "azure_openai":
        _check_pkg("langchain_openai")
        from langchain_openai import AzureChatOpenAI

        return AzureChatOpenAI(model=model, **kwargs)
    if model_provider == "azure_ai":
        _check_pkg("langchain_azure_ai")
        from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel

        return AzureAIChatCompletionsModel(model=model, **kwargs)
    if model_provider == "cohere":
        _check_pkg("langchain_cohere")
        from langchain_cohere import ChatCohere

        return ChatCohere(model=model, **kwargs)
    if model_provider == "google_vertexai":
        _check_pkg("langchain_google_vertexai")
        from langchain_google_vertexai import ChatVertexAI

        return ChatVertexAI(model=model, **kwargs)
    if model_provider == "google_genai":
        _check_pkg("langchain_google_genai")
        from langchain_google_genai import ChatGoogleGenerativeAI

        return ChatGoogleGenerativeAI(model=model, **kwargs)
    if model_provider == "fireworks":
        _check_pkg("langchain_fireworks")
        from langchain_fireworks import ChatFireworks

        return ChatFireworks(model=model, **kwargs)
    if model_provider == "ollama":
        try:
            _check_pkg("langchain_ollama")
            from langchain_ollama import ChatOllama
        except ImportError:
            # For backwards compatibility
            try:
                _check_pkg("langchain_community")
                from langchain_community.chat_models import ChatOllama
            except ImportError:
                # If both langchain-ollama and langchain-community aren't available,
                # raise an error related to langchain-ollama
                _check_pkg("langchain_ollama")

        return ChatOllama(model=model, **kwargs)
    if model_provider == "together":
        _check_pkg("langchain_together")
        from langchain_together import ChatTogether

        return ChatTogether(model=model, **kwargs)
    if model_provider == "mistralai":
        _check_pkg("langchain_mistralai")
        from langchain_mistralai import ChatMistralAI

        return ChatMistralAI(model=model, **kwargs)  # type: ignore[call-arg,unused-ignore]
    if model_provider == "huggingface":
        _check_pkg("langchain_huggingface")
        from langchain_huggingface import ChatHuggingFace

        return ChatHuggingFace(model_id=model, **kwargs)
    if model_provider == "groq":
        _check_pkg("langchain_groq")
        from langchain_groq import ChatGroq

        return ChatGroq(model=model, **kwargs)
    if model_provider == "bedrock":
        _check_pkg("langchain_aws")
        from langchain_aws import ChatBedrock

        return ChatBedrock(model_id=model, **kwargs)
    if model_provider == "bedrock_converse":
        _check_pkg("langchain_aws")
        from langchain_aws import ChatBedrockConverse

        return ChatBedrockConverse(model=model, **kwargs)
    if model_provider == "google_anthropic_vertex":
        _check_pkg("langchain_google_vertexai")
        from langchain_google_vertexai.model_garden import ChatAnthropicVertex

        return ChatAnthropicVertex(model=model, **kwargs)
    if model_provider == "deepseek":
        _check_pkg("langchain_deepseek", pkg_kebab="langchain-deepseek")
        from langchain_deepseek import ChatDeepSeek

        return ChatDeepSeek(model=model, **kwargs)
    if model_provider == "nvidia":
        _check_pkg("langchain_nvidia_ai_endpoints")
        from langchain_nvidia_ai_endpoints import ChatNVIDIA

        return ChatNVIDIA(model=model, **kwargs)
    if model_provider == "ibm":
        _check_pkg("langchain_ibm")
        from langchain_ibm import ChatWatsonx

        return ChatWatsonx(model_id=model, **kwargs)
    if model_provider == "xai":
        _check_pkg("langchain_xai")
        from langchain_xai import ChatXAI

        return ChatXAI(model=model, **kwargs)
    if model_provider == "perplexity":
        _check_pkg("langchain_perplexity")
        from langchain_perplexity import ChatPerplexity

        return ChatPerplexity(model=model, **kwargs)
    if model_provider == "upstage":
        _check_pkg("langchain_upstage")
        from langchain_upstage import ChatUpstage

        return ChatUpstage(model=model, **kwargs)
    supported = ", ".join(_SUPPORTED_PROVIDERS)
    msg = f"Unsupported {model_provider=}.\n\nSupported model providers are: {supported}"
    raise ValueError(msg)


# Note: `nvidia` is handled by `_init_chat_model_helper` and documented as
# supported above, so it is included here as well.
_SUPPORTED_PROVIDERS = {
    "openai",
    "anthropic",
    "azure_openai",
    "azure_ai",
    "cohere",
    "google_vertexai",
    "google_genai",
    "fireworks",
    "ollama",
    "together",
    "mistralai",
    "huggingface",
    "groq",
    "bedrock",
    "bedrock_converse",
    "google_anthropic_vertex",
    "deepseek",
    "ibm",
    "nvidia",
    "xai",
    "perplexity",
    "upstage",
}


def _attempt_infer_model_provider(model_name: str) -> str | None:
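    """Infer the provider from well-known model-name prefixes, or return `None`."""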
    if any(model_name.startswith(pre) for pre in ("gpt-", "o1", "o3")):
        return "openai"
    if model_name.startswith("claude"):
        return "anthropic"
    if model_name.startswith("command"):
        return "cohere"
    if model_name.startswith("accounts/fireworks"):
        return "fireworks"
    if model_name.startswith("gemini"):
        return "google_vertexai"
    if model_name.startswith("amazon."):
        return "bedrock"
    if model_name.startswith("mistral"):
        return "mistralai"
    if model_name.startswith("deepseek"):
        return "deepseek"
    if model_name.startswith("grok"):
        return "xai"
    if model_name.startswith("sonar"):
        return "perplexity"
    if model_name.startswith("solar"):
        return "upstage"
    return None


def _parse_model(model: str, model_provider: str | None) -> tuple[str, str]:
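    """Split an optional `'{provider}:{model}'` string and normalize the provider.

    Raises:
        ValueError: If no provider is given and none can be inferred.
    """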
    if not model_provider and ":" in model and model.split(":")[0] in _SUPPORTED_PROVIDERS:
        model_provider = model.split(":")[0]
        model = ":".join(model.split(":")[1:])
    model_provider = model_provider or _attempt_infer_model_provider(model)
    if not model_provider:
        msg = (
            f"Unable to infer model provider for {model=}, please specify "
            f"model_provider directly."
        )
        raise ValueError(msg)
    model_provider = model_provider.replace("-", "_").lower()
    return model, model_provider


def _check_pkg(pkg: str, *, pkg_kebab: str | None = None) -> None:
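    """Raise an informative `ImportError` if `pkg` is not importable."""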
    if not util.find_spec(pkg):
        pkg_kebab = pkg_kebab if pkg_kebab is not None else pkg.replace("_", "-")
        msg = f"Unable to import {pkg}. Please install with `pip install -U {pkg_kebab}`"
        raise ImportError(msg)


def _remove_prefix(s: str, prefix: str) -> str:
    return s.removeprefix(prefix)
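

# Declarative methods return a new runnable rather than model output. On
# _ConfigurableModel these calls are queued and replayed once the underlying
# model is actually instantiated (see __getattr__ and _model below).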
_DECLARATIVE_METHODS = ("bind_tools", "with_structured_output")


class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
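    """Chat model emulator that defers initialization until runtime config is seen.

    Merges `default_config` with any permitted `config["configurable"]` overrides
    (optionally namespaced by `config_prefix`) and instantiates the underlying
    chat model on each call.
    """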

    def __init__(
        self,
        *,
        default_config: dict | None = None,
        configurable_fields: Literal["any"] | list[str] | tuple[str, ...] = "any",
        config_prefix: str = "",
        queued_declarative_operations: Sequence[tuple[str, tuple, dict]] = (),
    ) -> None:
        self._default_config: dict = default_config or {}
        self._configurable_fields: Literal["any"] | list[str] = (
            configurable_fields if configurable_fields == "any" else list(configurable_fields)
        )
        self._config_prefix = (
            config_prefix + "_"
            if config_prefix and not config_prefix.endswith("_")
            else config_prefix
        )
        self._queued_declarative_operations: list[tuple[str, tuple, dict]] = list(
            queued_declarative_operations,
        )

    def __getattr__(self, name: str) -> Any:
        if name in _DECLARATIVE_METHODS:
            # Declarative operations that cannot be applied until after an actual model
            # object is instantiated. So instead of returning the actual operation,
            # we record the operation and its arguments in a queue. This queue is
            # then applied in order whenever we actually instantiate the model (in
            # self._model()).
            def queue(*args: Any, **kwargs: Any) -> _ConfigurableModel:
                queued_declarative_operations = list(
                    self._queued_declarative_operations,
                )
                queued_declarative_operations.append((name, args, kwargs))
                return _ConfigurableModel(
                    default_config=dict(self._default_config),
                    configurable_fields=list(self._configurable_fields)
                    if isinstance(self._configurable_fields, list)
                    else self._configurable_fields,
                    config_prefix=self._config_prefix,
                    queued_declarative_operations=queued_declarative_operations,
                )

            return queue
        if self._default_config and (model := self._model()) and hasattr(model, name):
            return getattr(model, name)
        msg = f"{name} is not a BaseChatModel attribute"
        if self._default_config:
            msg += " and is not implemented on the default model"
        msg += "."
        raise AttributeError(msg)

    def _model(self, config: RunnableConfig | None = None) -> Runnable:
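        # Merge default params with runtime config params, instantiate the model,
        # then replay any queued declarative operations (e.g. bind_tools) in order.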
        params = {**self._default_config, **self._model_params(config)}
        model = _init_chat_model_helper(**params)
        for name, args, kwargs in self._queued_declarative_operations:
            model = getattr(model, name)(*args, **kwargs)
        return model

    def _model_params(self, config: RunnableConfig | None) -> dict:
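        # Pull params from config["configurable"], keeping only keys with our
        # prefix (stripped) and, unless "any", only the declared configurable fields.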
        config = ensure_config(config)
        model_params = {
            _remove_prefix(k, self._config_prefix): v
            for k, v in config.get("configurable", {}).items()
            if k.startswith(self._config_prefix)
        }
        if self._configurable_fields != "any":
            model_params = {
                k: v for k, v in model_params.items() if k in self._configurable_fields
            }
        return model_params

    def with_config(
        self,
        config: RunnableConfig | None = None,
        **kwargs: Any,
    ) -> _ConfigurableModel:
        """Bind config to a `Runnable`, returning a new `Runnable`."""
        config = RunnableConfig(**(config or {}), **cast("RunnableConfig", kwargs))
        model_params = self._model_params(config)
        remaining_config = {k: v for k, v in config.items() if k != "configurable"}
        remaining_config["configurable"] = {
            k: v
            for k, v in config.get("configurable", {}).items()
            if _remove_prefix(k, self._config_prefix) not in model_params
        }
        queued_declarative_operations = list(self._queued_declarative_operations)
        if remaining_config:
            queued_declarative_operations.append(
                (
                    "with_config",
                    (),
                    {"config": remaining_config},
                ),
            )
        return _ConfigurableModel(
            default_config={**self._default_config, **model_params},
            configurable_fields=list(self._configurable_fields)
            if isinstance(self._configurable_fields, list)
            else self._configurable_fields,
            config_prefix=self._config_prefix,
            queued_declarative_operations=queued_declarative_operations,
        )

    @property
    def InputType(self) -> TypeAlias:
        """Get the input type for this `Runnable`."""
        from langchain_core.prompt_values import ChatPromptValueConcrete, StringPromptValue

        # This is a version of LanguageModelInput which replaces the abstract
        # base class BaseMessage with a union of its subclasses, which makes
        # for a much better schema.
        return str | StringPromptValue | ChatPromptValueConcrete | list[AnyMessage]

    @override
    def invoke(
        self,
        input: LanguageModelInput,
        config: RunnableConfig | None = None,
        **kwargs: Any,
    ) -> Any:
        return self._model(config).invoke(input, config=config, **kwargs)

    @override
    async def ainvoke(
        self,
        input: LanguageModelInput,
        config: RunnableConfig | None = None,
        **kwargs: Any,
    ) -> Any:
        return await self._model(config).ainvoke(input, config=config, **kwargs)

    @override
    def stream(
        self,
        input: LanguageModelInput,
        config: RunnableConfig | None = None,
        **kwargs: Any | None,
    ) -> Iterator[Any]:
        yield from self._model(config).stream(input, config=config, **kwargs)

    @override
    async def astream(
        self,
        input: LanguageModelInput,
        config: RunnableConfig | None = None,
        **kwargs: Any | None,
    ) -> AsyncIterator[Any]:
        async for x in self._model(config).astream(input, config=config, **kwargs):
            yield x

    def batch(
        self,
        inputs: list[LanguageModelInput],
        config: RunnableConfig | list[RunnableConfig] | None = None,
        *,
        return_exceptions: bool = False,
        **kwargs: Any | None,
    ) -> list[Any]:
        config = config or None
        # If <= 1 config, use the underlying model's batch implementation.
        if config is None or isinstance(config, dict) or len(config) <= 1:
            if isinstance(config, list):
                config = config[0]
            return self._model(config).batch(
                inputs,
                config=config,
                return_exceptions=return_exceptions,
                **kwargs,
            )
        # If multiple configs, default to Runnable.batch, which uses an executor
        # to invoke in parallel.
        return super().batch(
            inputs,
            config=config,
            return_exceptions=return_exceptions,
            **kwargs,
        )

    async def abatch(
        self,
        inputs: list[LanguageModelInput],
        config: RunnableConfig | list[RunnableConfig] | None = None,
        *,
        return_exceptions: bool = False,
        **kwargs: Any | None,
    ) -> list[Any]:
        config = config or None
        # If <= 1 config, use the underlying model's abatch implementation.
        if config is None or isinstance(config, dict) or len(config) <= 1:
            if isinstance(config, list):
                config = config[0]
            return await self._model(config).abatch(
                inputs,
                config=config,
                return_exceptions=return_exceptions,
                **kwargs,
            )
        # If multiple configs, default to Runnable.abatch, which uses an executor
        # to invoke in parallel.
        return await super().abatch(
            inputs,
            config=config,
            return_exceptions=return_exceptions,
            **kwargs,
        )

    def batch_as_completed(
        self,
        inputs: Sequence[LanguageModelInput],
        config: RunnableConfig | Sequence[RunnableConfig] | None = None,
        *,
        return_exceptions: bool = False,
        **kwargs: Any,
    ) -> Iterator[tuple[int, Any | Exception]]:
        config = config or None
        # If <= 1 config, use the underlying model's batch_as_completed.
        if config is None or isinstance(config, dict) or len(config) <= 1:
            if isinstance(config, list):
                config = config[0]
            yield from self._model(cast("RunnableConfig", config)).batch_as_completed(  # type: ignore[call-overload]
                inputs,
                config=config,
                return_exceptions=return_exceptions,
                **kwargs,
            )
        # If multiple configs, default to Runnable.batch_as_completed, which uses
        # an executor to invoke in parallel.
        else:
            yield from super().batch_as_completed(  # type: ignore[call-overload]
                inputs,
                config=config,
                return_exceptions=return_exceptions,
                **kwargs,
            )

    async def abatch_as_completed(
        self,
        inputs: Sequence[LanguageModelInput],
        config: RunnableConfig | Sequence[RunnableConfig] | None = None,
        *,
        return_exceptions: bool = False,
        **kwargs: Any,
    ) -> AsyncIterator[tuple[int, Any]]:
        config = config or None
        # If <= 1 config, use the underlying model's abatch_as_completed.
        if config is None or isinstance(config, dict) or len(config) <= 1:
            if isinstance(config, list):
                config = config[0]
            async for x in self._model(
                cast("RunnableConfig", config),
            ).abatch_as_completed(  # type: ignore[call-overload]
                inputs,
                config=config,
                return_exceptions=return_exceptions,
                **kwargs,
            ):
                yield x
        # If multiple configs, default to Runnable.abatch_as_completed, which uses
        # an executor to invoke in parallel.
        else:
            async for x in super().abatch_as_completed(  # type: ignore[call-overload]
                inputs,
                config=config,
                return_exceptions=return_exceptions,
                **kwargs,
            ):
                yield x

    @override
    def transform(
        self,
        input: Iterator[LanguageModelInput],
        config: RunnableConfig | None = None,
        **kwargs: Any | None,
    ) -> Iterator[Any]:
        yield from self._model(config).transform(input, config=config, **kwargs)

    @override
    async def atransform(
        self,
        input: AsyncIterator[LanguageModelInput],
        config: RunnableConfig | None = None,
        **kwargs: Any | None,
    ) -> AsyncIterator[Any]:
        async for x in self._model(config).atransform(input, config=config, **kwargs):
            yield x

    @overload
    def astream_log(
        self,
        input: Any,
        config: RunnableConfig | None = None,
        *,
        diff: Literal[True] = True,
        with_streamed_output_list: bool = True,
        include_names: Sequence[str] | None = None,
        include_types: Sequence[str] | None = None,
        include_tags: Sequence[str] | None = None,
        exclude_names: Sequence[str] | None = None,
        exclude_types: Sequence[str] | None = None,
        exclude_tags: Sequence[str] | None = None,
        **kwargs: Any,
    ) -> AsyncIterator[RunLogPatch]: ...

    @overload
    def astream_log(
        self,
        input: Any,
        config: RunnableConfig | None = None,
        *,
        diff: Literal[False],
        with_streamed_output_list: bool = True,
        include_names: Sequence[str] | None = None,
        include_types: Sequence[str] | None = None,
        include_tags: Sequence[str] | None = None,
        exclude_names: Sequence[str] | None = None,
        exclude_types: Sequence[str] | None = None,
        exclude_tags: Sequence[str] | None = None,
        **kwargs: Any,
    ) -> AsyncIterator[RunLog]: ...

    @override
    async def astream_log(
        self,
        input: Any,
        config: RunnableConfig | None = None,
        *,
        diff: bool = True,
        with_streamed_output_list: bool = True,
        include_names: Sequence[str] | None = None,
        include_types: Sequence[str] | None = None,
        include_tags: Sequence[str] | None = None,
        exclude_names: Sequence[str] | None = None,
        exclude_types: Sequence[str] | None = None,
        exclude_tags: Sequence[str] | None = None,
        **kwargs: Any,
    ) -> AsyncIterator[RunLogPatch] | AsyncIterator[RunLog]:
        async for x in self._model(config).astream_log(  # type: ignore[call-overload, misc]
            input,
            config=config,
            diff=diff,
            with_streamed_output_list=with_streamed_output_list,
            include_names=include_names,
            include_types=include_types,
            include_tags=include_tags,
            exclude_tags=exclude_tags,
            exclude_types=exclude_types,
            exclude_names=exclude_names,
            **kwargs,
        ):
            yield x

    @override
    async def astream_events(
        self,
        input: Any,
        config: RunnableConfig | None = None,
        *,
        version: Literal["v1", "v2"] = "v2",
        include_names: Sequence[str] | None = None,
        include_types: Sequence[str] | None = None,
        include_tags: Sequence[str] | None = None,
        exclude_names: Sequence[str] | None = None,
        exclude_types: Sequence[str] | None = None,
        exclude_tags: Sequence[str] | None = None,
        **kwargs: Any,
    ) -> AsyncIterator[StreamEvent]:
        async for x in self._model(config).astream_events(
            input,
            config=config,
            version=version,
            include_names=include_names,
            include_types=include_types,
            include_tags=include_tags,
            exclude_tags=exclude_tags,
            exclude_types=exclude_types,
            exclude_names=exclude_names,
            **kwargs,
        ):
            yield x

    # Explicitly added to satisfy downstream linters.
    def bind_tools(
        self,
        tools: Sequence[dict[str, Any] | type[BaseModel] | Callable | BaseTool],
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, AIMessage]:
        return self.__getattr__("bind_tools")(tools, **kwargs)

    # Explicitly added to satisfy downstream linters.
    def with_structured_output(
        self,
        schema: dict | type[BaseModel],
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, dict | BaseModel]:
        return self.__getattr__("with_structured_output")(schema, **kwargs)