azure.py

  1. """Azure OpenAI large language models. Not to be confused with chat models."""
  2. from __future__ import annotations
  3. import logging
  4. from collections.abc import Awaitable, Callable, Mapping
  5. from typing import Any, cast
  6. import openai
  7. from langchain_core.language_models import LangSmithParams
  8. from langchain_core.utils import from_env, secret_from_env
  9. from pydantic import Field, SecretStr, model_validator
  10. from typing_extensions import Self
  11. from langchain_openai.llms.base import BaseOpenAI
  12. logger = logging.getLogger(__name__)
class AzureOpenAI(BaseOpenAI):
    """Azure-specific OpenAI large language models.

    To use, you should have the `openai` python package installed and the
    environment variable `AZURE_OPENAI_API_KEY` set with your API key
    (`OPENAI_API_KEY` is also checked, for backwards compatibility).

    Any parameters that are valid to be passed to the openai.create call can be
    passed in, even if not explicitly saved on this class.

    Example:
        ```python
        from langchain_openai import AzureOpenAI

        llm = AzureOpenAI(model_name="gpt-3.5-turbo-instruct")
        ```
    """
    azure_endpoint: str | None = Field(
        default_factory=from_env("AZURE_OPENAI_ENDPOINT", default=None)
    )
    """Your Azure endpoint, including the resource.

    Automatically inferred from env var `AZURE_OPENAI_ENDPOINT` if not provided.

    Example: `'https://example-resource.azure.openai.com/'`
    """
    deployment_name: str | None = Field(default=None, alias="azure_deployment")
    """A model deployment.

    If given, sets the base client URL to include `/deployments/{azure_deployment}`.

    !!! note
        This means you won't be able to use non-deployment endpoints.
    """
    openai_api_version: str | None = Field(
        alias="api_version",
        default_factory=from_env("OPENAI_API_VERSION", default=None),
    )
    """Automatically inferred from env var `OPENAI_API_VERSION` if not provided."""
    # Check OPENAI_API_KEY for backwards compatibility.
    # TODO: Remove OPENAI_API_KEY support to avoid possible conflict when using
    # other forms of azure credentials.
    openai_api_key: SecretStr | None = Field(
        alias="api_key",
        default_factory=secret_from_env(
            ["AZURE_OPENAI_API_KEY", "OPENAI_API_KEY"], default=None
        ),
    )
    azure_ad_token: SecretStr | None = Field(
        default_factory=secret_from_env("AZURE_OPENAI_AD_TOKEN", default=None)
    )
    """Your Azure Active Directory token.

    Automatically inferred from env var `AZURE_OPENAI_AD_TOKEN` if not provided.

    For more, see
    [this page](https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id).
    """
    azure_ad_token_provider: Callable[[], str] | None = None
    """A function that returns an Azure Active Directory token.

    Will be invoked on every sync request. For async requests,
    will be invoked if `azure_ad_async_token_provider` is not provided.
    """
    azure_ad_async_token_provider: Callable[[], Awaitable[str]] | None = None
    """A function that returns an Azure Active Directory token.

    Will be invoked on every async request.
    """
    openai_api_type: str | None = Field(
        default_factory=from_env("OPENAI_API_TYPE", default="azure")
    )
    """Legacy, for `openai<1.0.0` support."""
    validate_base_url: bool = True
    """For backwards compatibility. If the legacy arg `openai_api_base` is passed in,
    try to infer whether it is a `base_url` or an `azure_endpoint`, and update
    accordingly.
    """
    @classmethod
    def get_lc_namespace(cls) -> list[str]:
        """Get the namespace of the LangChain object.

        Returns:
            `["langchain", "llms", "openai"]`
        """
        return ["langchain", "llms", "openai"]

    @property
    def lc_secrets(self) -> dict[str, str]:
        """Mapping of secret keys to environment variables."""
        return {
            "openai_api_key": "AZURE_OPENAI_API_KEY",
            "azure_ad_token": "AZURE_OPENAI_AD_TOKEN",
        }

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this model can be serialized by LangChain."""
        return True
    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Validate that api key and python package exist in environment."""
        if self.n < 1:
            msg = "n must be at least 1."
            raise ValueError(msg)
        if self.streaming and self.n > 1:
            msg = "Cannot stream results when n > 1."
            raise ValueError(msg)
        if self.streaming and self.best_of > 1:
            msg = "Cannot stream results when best_of > 1."
            raise ValueError(msg)

        # For backwards compatibility. Before openai v1, no distinction was made
        # between azure_endpoint and base_url (openai_api_base).
        openai_api_base = self.openai_api_base
        if openai_api_base and self.validate_base_url:
            if "/openai" not in openai_api_base:
                self.openai_api_base = (
                    cast(str, self.openai_api_base).rstrip("/") + "/openai"
                )
                msg = (
                    "As of openai>=1.0.0, Azure endpoints should be specified via "
                    "the `azure_endpoint` param not `openai_api_base` "
                    "(or alias `base_url`)."
                )
                raise ValueError(msg)
            if self.deployment_name:
                msg = (
                    "As of openai>=1.0.0, if `deployment_name` (or alias "
                    "`azure_deployment`) is specified then "
                    "`openai_api_base` (or alias `base_url`) should not be. "
                    "Instead use `deployment_name` (or alias `azure_deployment`) "
                    "and `azure_endpoint`."
                )
                raise ValueError(msg)
            self.deployment_name = None

        # Keyword arguments shared by the sync and async Azure clients.
        client_params: dict = {
            "api_version": self.openai_api_version,
            "azure_endpoint": self.azure_endpoint,
            "azure_deployment": self.deployment_name,
            "api_key": self.openai_api_key.get_secret_value()
            if self.openai_api_key
            else None,
            "azure_ad_token": self.azure_ad_token.get_secret_value()
            if self.azure_ad_token
            else None,
            "azure_ad_token_provider": self.azure_ad_token_provider,
            "organization": self.openai_organization,
            "base_url": self.openai_api_base,
            "timeout": self.request_timeout,
            "max_retries": self.max_retries,
            "default_headers": {
                **(self.default_headers or {}),
                "User-Agent": "langchain-partner-python-azure-openai",
            },
            "default_query": self.default_query,
        }
        if not self.client:
            sync_specific = {"http_client": self.http_client}
            self.client = openai.AzureOpenAI(
                **client_params,
                **sync_specific,  # type: ignore[arg-type]
            ).completions
        if not self.async_client:
            async_specific = {"http_client": self.http_async_client}
            # Prefer the dedicated async token provider for the async client, if given.
            if self.azure_ad_async_token_provider:
                client_params["azure_ad_token_provider"] = (
                    self.azure_ad_async_token_provider
                )
            self.async_client = openai.AsyncAzureOpenAI(
                **client_params,
                **async_specific,  # type: ignore[arg-type]
            ).completions
        return self
    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {
            "deployment_name": self.deployment_name,
            **super()._identifying_params,
        }

    @property
    def _invocation_params(self) -> dict[str, Any]:
        openai_params = {"model": self.deployment_name}
        return {**openai_params, **super()._invocation_params}

    def _get_ls_params(
        self, stop: list[str] | None = None, **kwargs: Any
    ) -> LangSmithParams:
        """Get standard params for tracing."""
        params = super()._get_ls_params(stop=stop, **kwargs)
        invocation_params = self._invocation_params
        params["ls_provider"] = "azure"
        if model_name := invocation_params.get("model"):
            params["ls_model_name"] = model_name
        return params

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "azure"

    @property
    def lc_attributes(self) -> dict[str, Any]:
        """Attributes relevant to tracing."""
        return {
            "openai_api_type": self.openai_api_type,
            "openai_api_version": self.openai_api_version,
        }
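# Usage sketch (illustrative, not from the original module): assumes the environment
# variables documented above are set and that a completions-capable deployment
# exists. The deployment name below is hypothetical.
#
#     llm = AzureOpenAI(azure_deployment="my-completions-deployment")
#     print(llm.invoke("Tell me a joke"))
#
#     # Streaming is available through the standard LangChain LLM interface:
#     for chunk in llm.stream("Tell me a joke"):
#         print(chunk, end="")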