__init__.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395
  1. # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
  2. from __future__ import annotations
  3. import os as _os
  4. import typing as _t
  5. from typing_extensions import override
  6. from . import types
  7. from ._types import NOT_GIVEN, Omit, NoneType, NotGiven, Transport, ProxiesTypes, omit, not_given
  8. from ._utils import file_from_path
  9. from ._client import Client, OpenAI, Stream, Timeout, Transport, AsyncClient, AsyncOpenAI, AsyncStream, RequestOptions
  10. from ._models import BaseModel
  11. from ._version import __title__, __version__
  12. from ._response import APIResponse as APIResponse, AsyncAPIResponse as AsyncAPIResponse
  13. from ._constants import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES, DEFAULT_CONNECTION_LIMITS
  14. from ._exceptions import (
  15. APIError,
  16. OpenAIError,
  17. ConflictError,
  18. NotFoundError,
  19. APIStatusError,
  20. RateLimitError,
  21. APITimeoutError,
  22. BadRequestError,
  23. APIConnectionError,
  24. AuthenticationError,
  25. InternalServerError,
  26. PermissionDeniedError,
  27. LengthFinishReasonError,
  28. UnprocessableEntityError,
  29. APIResponseValidationError,
  30. InvalidWebhookSignatureError,
  31. ContentFilterFinishReasonError,
  32. )
  33. from ._base_client import DefaultHttpxClient, DefaultAioHttpClient, DefaultAsyncHttpxClient
  34. from ._utils._logs import setup_logging as _setup_logging
  35. from ._legacy_response import HttpxBinaryResponseContent as HttpxBinaryResponseContent
  # Names exported by `from openai import *`, grouped by the module they come from.
  __all__ = [
      # `from . import types` and `._version`
      "types",
      "__version__",
      "__title__",
      # `._types`
      "NoneType",
      "Transport",
      "ProxiesTypes",
      "NotGiven",
      "NOT_GIVEN",
      "not_given",
      "Omit",
      "omit",
      # `._exceptions`
      "OpenAIError",
      "APIError",
      "APIStatusError",
      "APITimeoutError",
      "APIConnectionError",
      "APIResponseValidationError",
      "BadRequestError",
      "AuthenticationError",
      "PermissionDeniedError",
      "NotFoundError",
      "ConflictError",
      "UnprocessableEntityError",
      "RateLimitError",
      "InternalServerError",
      "LengthFinishReasonError",
      "ContentFilterFinishReasonError",
      "InvalidWebhookSignatureError",
      # `._client`
      "Timeout",
      "RequestOptions",
      "Client",
      "AsyncClient",
      "Stream",
      "AsyncStream",
      "OpenAI",
      "AsyncOpenAI",
      # `._utils`, `._models`, `._constants` and `._base_client`
      "file_from_path",
      "BaseModel",
      "DEFAULT_TIMEOUT",
      "DEFAULT_MAX_RETRIES",
      "DEFAULT_CONNECTION_LIMITS",
      "DefaultHttpxClient",
      "DefaultAsyncHttpxClient",
      "DefaultAioHttpClient",
  ]
  82. if not _t.TYPE_CHECKING:
  83. from ._utils._resources_proxy import resources as resources
  84. from .lib import azure as _azure, pydantic_function_tool as pydantic_function_tool
  85. from .version import VERSION as VERSION
  86. from .lib.azure import AzureOpenAI as AzureOpenAI, AsyncAzureOpenAI as AsyncAzureOpenAI
  87. from .lib._old_api import *
  88. from .lib.streaming import (
  89. AssistantEventHandler as AssistantEventHandler,
  90. AsyncAssistantEventHandler as AsyncAssistantEventHandler,
  91. )
  _setup_logging()

  # Update the __module__ attribute for exported symbols so that
  # error messages point to this module instead of the module
  # it was originally defined in, e.g.
  # openai._exceptions.NotFoundError -> openai.NotFoundError
  __locals = locals()
  for __name in __all__:
      if __name.startswith("__"):
          # Dunder exports (e.g. __version__, __title__) are left as they are.
          continue
      try:
          setattr(__locals[__name], "__module__", "openai")
      except (TypeError, AttributeError):
          # Some of our exported symbols are builtins which we can't set attributes for.
          pass
  105. # ------ Module level client ------
  106. import typing as _t
  107. import typing_extensions as _te
  108. import httpx as _httpx
  109. from ._base_client import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES
  # Module-level client settings. `_load_client` passes these to the client it
  # builds, and `_ModuleClient` reads/rebinds most of them through properties.
  api_key: str | None = None
  organization: str | None = None
  project: str | None = None
  webhook_secret: str | None = None
  base_url: str | _httpx.URL | None = None
  timeout: float | Timeout | None = DEFAULT_TIMEOUT
  max_retries: int = DEFAULT_MAX_RETRIES
  default_headers: _t.Mapping[str, str] | None = None
  default_query: _t.Mapping[str, object] | None = None
  http_client: _httpx.Client | None = None

  _ApiType = _te.Literal["openai", "azure"]

  # The settings below are seeded from environment variables when this module is imported.
  # NOTE(review): `_t.cast` performs no validation, so any OPENAI_API_TYPE value
  # (not only "openai"/"azure") ends up in `api_type` as-is.
  api_type: _ApiType | None = _t.cast(_ApiType, _os.environ.get("OPENAI_API_TYPE"))
  api_version: str | None = _os.environ.get("OPENAI_API_VERSION")
  azure_endpoint: str | None = _os.environ.get("AZURE_OPENAI_ENDPOINT")
  azure_ad_token: str | None = _os.environ.get("AZURE_OPENAI_AD_TOKEN")
  azure_ad_token_provider: _azure.AzureADTokenProvider | None = None
  class _ModuleClient(OpenAI):
      """`OpenAI` client that keeps its configuration in this module's globals.

      Each overridden property reads a module-level variable and, on assignment,
      rebinds that variable via ``global``:

      * ``api_key``, ``organization``, ``project``, ``webhook_secret``,
        ``timeout`` and ``max_retries`` map to the globals of the same name;
      * ``_custom_headers`` maps to ``default_headers``;
      * ``_custom_query`` maps to ``default_query``;
      * ``_client`` maps to ``http_client``.

      ``base_url`` is the exception. A non-``None`` module-level ``base_url``
      wins on read, and assignment is forwarded to the parent class's setter.
      """

      # Note: we have to use type: ignores here as overriding class members
      # with properties is technically unsafe but it is fine for our use case

      @property  # type: ignore
      @override
      def api_key(self) -> str | None:
          return api_key

      @api_key.setter  # type: ignore
      def api_key(self, value: str | None) -> None:  # type: ignore
          global api_key

          api_key = value

      @property  # type: ignore
      @override
      def organization(self) -> str | None:
          return organization

      @organization.setter  # type: ignore
      def organization(self, value: str | None) -> None:  # type: ignore
          global organization

          organization = value

      @property  # type: ignore
      @override
      def project(self) -> str | None:
          return project

      @project.setter  # type: ignore
      def project(self, value: str | None) -> None:  # type: ignore
          global project

          project = value

      @property  # type: ignore
      @override
      def webhook_secret(self) -> str | None:
          return webhook_secret

      @webhook_secret.setter  # type: ignore
      def webhook_secret(self, value: str | None) -> None:  # type: ignore
          global webhook_secret

          webhook_secret = value

      @property
      @override
      def base_url(self) -> _httpx.URL:
          # A module-level override (str or URL) takes precedence over whatever
          # the parent class reports.
          if base_url is not None:
              return _httpx.URL(base_url)

          return super().base_url

      @base_url.setter
      def base_url(self, url: _httpx.URL | str) -> None:
          # `super().base_url = url` cannot work here. `super()` proxies attribute
          # *lookup* only, so assigning through it raises AttributeError instead of
          # reaching the parent's property setter. Instead, fetch the parent's
          # property object from the MRO (starting after this class, exactly where
          # `super()` would look) and invoke its setter explicitly.
          # NOTE(review): assumes the parent defines `base_url` as a property with a setter.
          parent_base_url = super(_ModuleClient, type(self)).base_url  # type: ignore[misc]
          parent_base_url.fset(self, url)  # type: ignore

      @property  # type: ignore
      @override
      def timeout(self) -> float | Timeout | None:
          return timeout

      @timeout.setter  # type: ignore
      def timeout(self, value: float | Timeout | None) -> None:  # type: ignore
          global timeout

          timeout = value

      @property  # type: ignore
      @override
      def max_retries(self) -> int:
          return max_retries

      @max_retries.setter  # type: ignore
      def max_retries(self, value: int) -> None:  # type: ignore
          global max_retries

          max_retries = value

      @property  # type: ignore
      @override
      def _custom_headers(self) -> _t.Mapping[str, str] | None:
          return default_headers

      @_custom_headers.setter  # type: ignore
      def _custom_headers(self, value: _t.Mapping[str, str] | None) -> None:  # type: ignore
          global default_headers

          default_headers = value

      @property  # type: ignore
      @override
      def _custom_query(self) -> _t.Mapping[str, object] | None:
          return default_query

      @_custom_query.setter  # type: ignore
      def _custom_query(self, value: _t.Mapping[str, object] | None) -> None:  # type: ignore
          global default_query

          default_query = value

      @property  # type: ignore
      @override
      def _client(self) -> _httpx.Client:
          # Prefer the module-level `http_client`; fall back to the parent's
          # `_client` when it is falsy (e.g. None).
          return http_client or super()._client

      @_client.setter  # type: ignore
      def _client(self, value: _httpx.Client) -> None:  # type: ignore
          global http_client

          http_client = value
  class _AzureModuleClient(_ModuleClient, AzureOpenAI):  # type: ignore
      """Azure variant of the module-level client.

      `_ModuleClient` is listed before `AzureOpenAI` in the bases, so its
      module-global property overrides take precedence in the MRO.
      """
  class _AmbiguousModuleClientUsageError(OpenAIError):
      """Raised when the module client cannot decide between the OpenAI and Azure backends."""

      def __init__(self) -> None:
          message = "Ambiguous use of module client; please set `openai.api_type` or the `OPENAI_API_TYPE` environment variable to `openai` or `azure`"
          super().__init__(message)
  def _has_openai_credentials() -> bool:
      """Report whether ``OPENAI_API_KEY`` is present in the environment.

      Presence is all that counts; an empty value still returns True.
      """
      return "OPENAI_API_KEY" in _os.environ
  def _has_azure_credentials() -> bool:
      """Report whether Azure endpoint/key settings are available.

      True when the module-level ``azure_endpoint`` is set, or when
      ``AZURE_OPENAI_API_KEY`` is present in the environment.
      """
      if azure_endpoint is not None:
          return True
      return "AZURE_OPENAI_API_KEY" in _os.environ
  def _has_azure_ad_credentials() -> bool:
      """Report whether any Azure AD authentication source is configured.

      Checks, in order: the ``AZURE_OPENAI_AD_TOKEN`` environment variable, the
      module-level ``azure_ad_token``, and the module-level ``azure_ad_token_provider``.
      """
      if "AZURE_OPENAI_AD_TOKEN" in _os.environ:
          return True
      if azure_ad_token is not None:
          return True
      return azure_ad_token_provider is not None
  # Cached client instance; created lazily by `_load_client` and cleared by `_reset_client`.
  _client: OpenAI | None = None


  def _load_client() -> OpenAI:  # type: ignore[reportUnusedFunction]
      """Return the cached module-level client, building it on the first call.

      Before building, unset Azure settings (``azure_endpoint``, ``azure_ad_token``,
      ``api_version``) are backfilled from their environment variables. If
      ``api_type`` is None, it is inferred from the configured credentials and
      stored back into the module-level ``api_type``: "azure" when Azure
      endpoint/key or Azure AD credentials exist, otherwise "openai".

      Returns:
          An `_AzureModuleClient` when ``api_type == "azure"``, else a `_ModuleClient`.

      Raises:
          _AmbiguousModuleClientUsageError: while inferring ``api_type``, if an
              OpenAI API key and Azure credentials are both present, or if an Azure
              AD token/provider is set together with ``AZURE_OPENAI_API_KEY``.
      """
      global _client, api_type, azure_endpoint, azure_ad_token, api_version

      # Fast path: a previous call already built the client.
      if _client is not None:
          return _client

      azure_endpoint = azure_endpoint if azure_endpoint is not None else _os.environ.get("AZURE_OPENAI_ENDPOINT")
      azure_ad_token = azure_ad_token if azure_ad_token is not None else _os.environ.get("AZURE_OPENAI_AD_TOKEN")
      api_version = api_version if api_version is not None else _os.environ.get("OPENAI_API_VERSION")

      if api_type is None:
          has_openai = _has_openai_credentials()
          has_azure = _has_azure_credentials()
          has_azure_ad = _has_azure_ad_credentials()
          looks_like_azure = has_azure or has_azure_ad

          # Credentials for both backends are present: refuse to guess.
          if has_openai and looks_like_azure:
              raise _AmbiguousModuleClientUsageError()

          # An Azure AD token/provider alongside an Azure API key is also treated as ambiguous.
          uses_azure_ad = azure_ad_token is not None or azure_ad_token_provider is not None
          if uses_azure_ad and _os.environ.get("AZURE_OPENAI_API_KEY") is not None:
              raise _AmbiguousModuleClientUsageError()

          api_type = "azure" if looks_like_azure else "openai"

      if api_type == "azure":
          # Note: `project` and `webhook_secret` are not forwarded to the Azure client.
          _client = _AzureModuleClient(  # type: ignore
              api_version=api_version,
              azure_endpoint=azure_endpoint,
              api_key=api_key,
              azure_ad_token=azure_ad_token,
              azure_ad_token_provider=azure_ad_token_provider,
              organization=organization,
              base_url=base_url,
              timeout=timeout,
              max_retries=max_retries,
              default_headers=default_headers,
              default_query=default_query,
              http_client=http_client,
          )
      else:
          _client = _ModuleClient(
              api_key=api_key,
              organization=organization,
              project=project,
              webhook_secret=webhook_secret,
              base_url=base_url,
              timeout=timeout,
              max_retries=max_retries,
              default_headers=default_headers,
              default_query=default_query,
              http_client=http_client,
          )

      return _client
  def _reset_client() -> None:  # type: ignore[reportUnusedFunction]
      """Forget the cached module-level client so the next `_load_client()` call rebuilds it."""
      global _client
      _client = None
  285. from ._module_client import (
  286. beta as beta,
  287. chat as chat,
  288. audio as audio,
  289. evals as evals,
  290. files as files,
  291. images as images,
  292. models as models,
  293. videos as videos,
  294. batches as batches,
  295. uploads as uploads,
  296. realtime as realtime,
  297. webhooks as webhooks,
  298. responses as responses,
  299. containers as containers,
  300. embeddings as embeddings,
  301. completions as completions,
  302. fine_tuning as fine_tuning,
  303. moderations as moderations,
  304. conversations as conversations,
  305. vector_stores as vector_stores,
  306. )