base.py

  1. """Base classes for OpenAI large language models. Chat models are in `chat_models/`."""
  2. from __future__ import annotations
  3. import logging
  4. import sys
  5. from collections.abc import AsyncIterator, Callable, Collection, Iterator, Mapping
  6. from typing import Any, Literal
  7. import openai
  8. import tiktoken
  9. from langchain_core.callbacks import (
  10. AsyncCallbackManagerForLLMRun,
  11. CallbackManagerForLLMRun,
  12. )
  13. from langchain_core.language_models.llms import BaseLLM
  14. from langchain_core.outputs import Generation, GenerationChunk, LLMResult
  15. from langchain_core.utils import get_pydantic_field_names
  16. from langchain_core.utils.utils import _build_model_kwargs, from_env, secret_from_env
  17. from pydantic import ConfigDict, Field, SecretStr, model_validator
  18. from typing_extensions import Self
  19. logger = logging.getLogger(__name__)


def _update_token_usage(
    keys: set[str], response: dict[str, Any], token_usage: dict[str, Any]
) -> None:
    """Update token usage."""
    _keys_to_use = keys.intersection(response["usage"])
    for _key in _keys_to_use:
        if _key not in token_usage:
            token_usage[_key] = response["usage"][_key]
        else:
            token_usage[_key] += response["usage"][_key]
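

# For example, merging usage from two batched responses accumulates per key:
#     usage: dict[str, Any] = {}
#     _update_token_usage({"total_tokens"}, {"usage": {"total_tokens": 10}}, usage)
#     _update_token_usage({"total_tokens"}, {"usage": {"total_tokens": 5}}, usage)
#     # usage == {"total_tokens": 15}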


def _stream_response_to_generation_chunk(
    stream_response: dict[str, Any],
) -> GenerationChunk:
    """Convert a stream response to a generation chunk."""
    if not stream_response["choices"]:
        return GenerationChunk(text="")
    return GenerationChunk(
        text=stream_response["choices"][0]["text"] or "",
        generation_info={
            "finish_reason": stream_response["choices"][0].get("finish_reason", None),
            "logprobs": stream_response["choices"][0].get("logprobs", None),
        },
    )
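

# A streamed completions chunk, once dumped to a dict, is expected to look roughly
# like the following (illustrative values only):
#     {"choices": [{"text": "a", "finish_reason": None, "logprobs": None}], ...}
# so the helper above only needs the first choice's text, finish_reason and logprobs.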


class BaseOpenAI(BaseLLM):
    """Base OpenAI large language model class.

    Setup:
        Install `langchain-openai` and set environment variable `OPENAI_API_KEY`.

        ```bash
        pip install -U langchain-openai
        export OPENAI_API_KEY="your-api-key"
        ```

    Key init args — completion params:
        model_name:
            Name of OpenAI model to use.
        temperature:
            Sampling temperature.
        max_tokens:
            Max number of tokens to generate.
        top_p:
            Total probability mass of tokens to consider at each step.
        frequency_penalty:
            Penalizes repeated tokens according to frequency.
        presence_penalty:
            Penalizes repeated tokens.
        n:
            How many completions to generate for each prompt.
        best_of:
            Generates best_of completions server-side and returns the "best".
        logit_bias:
            Adjust the probability of specific tokens being generated.
        seed:
            Seed for generation.
        logprobs:
            Include the log probabilities on the logprobs most likely output tokens.
        streaming:
            Whether to stream the results or not.

    Key init args — client params:
        openai_api_key:
            OpenAI API key. If not passed in, will be read from env var
            `OPENAI_API_KEY`.
        openai_api_base:
            Base URL path for API requests; leave blank if not using a proxy or
            service emulator.
        openai_organization:
            OpenAI organization ID. If not passed in, will be read from env
            var `OPENAI_ORG_ID`.
        request_timeout:
            Timeout for requests to the OpenAI completion API.
        max_retries:
            Maximum number of retries to make when generating.
        batch_size:
            Batch size to use when passing multiple documents to generate.

    See full list of supported init args and their descriptions in the params section.

    Instantiate:
        ```python
        from langchain_openai.llms.base import BaseOpenAI

        model = BaseOpenAI(
            model_name="gpt-3.5-turbo-instruct",
            temperature=0.7,
            max_tokens=256,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0,
            # openai_api_key="...",
            # openai_api_base="...",
            # openai_organization="...",
            # other params...
        )
        ```

    Invoke:
        ```python
        input_text = "The meaning of life is "
        response = model.invoke(input_text)
        print(response)
        ```

        ```txt
        "a philosophical question that has been debated by thinkers and
        scholars for centuries."
        ```

    Stream:
        ```python
        for chunk in model.stream(input_text):
            print(chunk, end="")
        ```

        ```txt
        a philosophical question that has been debated by thinkers and
        scholars for centuries.
        ```

    Async:
        ```python
        response = await model.ainvoke(input_text)

        # stream:
        # async for chunk in model.astream(input_text):
        #     print(chunk, end="")

        # batch:
        # await model.abatch([input_text])
        ```

        ```txt
        "a philosophical question that has been debated by thinkers and
        scholars for centuries."
        ```

    """

    client: Any = Field(default=None, exclude=True)
    async_client: Any = Field(default=None, exclude=True)
    model_name: str = Field(default="gpt-3.5-turbo-instruct", alias="model")
    """Model name to use."""
    temperature: float = 0.7
    """What sampling temperature to use."""
    max_tokens: int = 256
    """The maximum number of tokens to generate in the completion.

    -1 returns as many tokens as possible given the prompt and
    the model's maximal context size."""
    top_p: float = 1
    """Total probability mass of tokens to consider at each step."""
    frequency_penalty: float = 0
    """Penalizes repeated tokens according to frequency."""
    presence_penalty: float = 0
    """Penalizes repeated tokens."""
    n: int = 1
    """How many completions to generate for each prompt."""
    best_of: int = 1
    """Generates best_of completions server-side and returns the "best"."""
    model_kwargs: dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for the `create` call not explicitly
    specified."""
    openai_api_key: SecretStr | None | Callable[[], str] = Field(
        alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None)
    )
    """Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
    openai_api_base: str | None = Field(
        alias="base_url", default_factory=from_env("OPENAI_API_BASE", default=None)
    )
    """Base URL path for API requests; leave blank if not using a proxy or service
    emulator."""
    openai_organization: str | None = Field(
        alias="organization",
        default_factory=from_env(
            ["OPENAI_ORG_ID", "OPENAI_ORGANIZATION"], default=None
        ),
    )
    """Automatically inferred from env var `OPENAI_ORG_ID` if not provided."""
    # to support explicit proxy for OpenAI
    openai_proxy: str | None = Field(
        default_factory=from_env("OPENAI_PROXY", default=None)
    )
    batch_size: int = 20
    """Batch size to use when passing multiple documents to generate."""
    request_timeout: float | tuple[float, float] | Any | None = Field(
        default=None, alias="timeout"
    )
    """Timeout for requests to OpenAI completion API. Can be float, `httpx.Timeout` or
    None."""
    logit_bias: dict[str, float] | None = None
    """Adjust the probability of specific tokens being generated."""
    max_retries: int = 2
    """Maximum number of retries to make when generating."""
    seed: int | None = None
    """Seed for generation."""
    logprobs: int | None = None
    """Include the log probabilities on the logprobs most likely output tokens,
    as well as the chosen tokens."""
    streaming: bool = False
    """Whether to stream the results or not."""
    allowed_special: Literal["all"] | set[str] = set()
    """Set of special tokens that are allowed."""
    disallowed_special: Literal["all"] | Collection[str] = "all"
    """Set of special tokens that are not allowed."""
    tiktoken_model_name: str | None = None
    """The model name to pass to tiktoken when using this class.

    Tiktoken is used to count the number of tokens in documents to constrain
    them to be under a certain limit.

    By default, when set to `None`, this will be the same as `model_name`. However,
    there are some cases where you may want to use this LLM class with a model name
    not supported by tiktoken. This can include when using Azure, or when using one
    of the many model providers that expose an OpenAI-like API but with different
    models. In those cases, in order to avoid erroring when tiktoken is called, you
    can specify a model name to use here.
    """
    default_headers: Mapping[str, str] | None = None
    default_query: Mapping[str, object] | None = None
    # Configure a custom httpx client. See the
    # [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
    http_client: Any | None = None
    """Optional `httpx.Client`.

    Only used for sync invocations. Must specify `http_async_client` as well if you'd
    like a custom client for async invocations.
    """
    http_async_client: Any | None = None
    """Optional `httpx.AsyncClient`.

    Only used for async invocations. Must specify `http_client` as well if you'd like a
    custom client for sync invocations.
    """
    extra_body: Mapping[str, Any] | None = None
    """Optional additional JSON properties to include in the request parameters when
    making requests to OpenAI compatible APIs, such as vLLM."""

    model_config = ConfigDict(populate_by_name=True)

    @model_validator(mode="before")
    @classmethod
    def build_extra(cls, values: dict[str, Any]) -> Any:
        """Build extra kwargs from additional params that were passed in."""
        all_required_field_names = get_pydantic_field_names(cls)
        return _build_model_kwargs(values, all_required_field_names)

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Validate that api key and python package exist in environment."""
        if self.n < 1:
            msg = "n must be at least 1."
            raise ValueError(msg)
        if self.streaming and self.n > 1:
            msg = "Cannot stream results when n > 1."
            raise ValueError(msg)
        if self.streaming and self.best_of > 1:
            msg = "Cannot stream results when best_of > 1."
            raise ValueError(msg)
        # Resolve the API key from a SecretStr or a callable.
        api_key_value: str | Callable[[], str] | None = None
        if self.openai_api_key is not None:
            if isinstance(self.openai_api_key, SecretStr):
                api_key_value = self.openai_api_key.get_secret_value()
            elif callable(self.openai_api_key):
                api_key_value = self.openai_api_key
        client_params: dict = {
            "api_key": api_key_value,
            "organization": self.openai_organization,
            "base_url": self.openai_api_base,
            "timeout": self.request_timeout,
            "max_retries": self.max_retries,
            "default_headers": self.default_headers,
            "default_query": self.default_query,
        }
        if not self.client:
            sync_specific = {"http_client": self.http_client}
            self.client = openai.OpenAI(**client_params, **sync_specific).completions  # type: ignore[arg-type]
        if not self.async_client:
            async_specific = {"http_client": self.http_async_client}
            self.async_client = openai.AsyncOpenAI(
                **client_params,
                **async_specific,  # type: ignore[arg-type]
            ).completions
        return self

    @property
    def _default_params(self) -> dict[str, Any]:
        """Get the default parameters for calling OpenAI API."""
        normal_params: dict[str, Any] = {
            "temperature": self.temperature,
            "top_p": self.top_p,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
            "n": self.n,
            "seed": self.seed,
            "logprobs": self.logprobs,
        }
        if self.logit_bias is not None:
            normal_params["logit_bias"] = self.logit_bias
        if self.max_tokens is not None:
            normal_params["max_tokens"] = self.max_tokens
        if self.extra_body is not None:
            normal_params["extra_body"] = self.extra_body
        # Azure gpt-35-turbo doesn't support best_of;
        # don't specify best_of if it is 1.
        if self.best_of > 1:
            normal_params["best_of"] = self.best_of
        return {**normal_params, **self.model_kwargs}
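
    # Because `model_kwargs` is merged last, any entry placed there is sent as-is and
    # takes precedence over the named defaults above. A sketch (assuming `suffix` is
    # accepted by the completions endpoint being targeted):
    #     OpenAI(model_kwargs={"suffix": " The End."})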

    def _stream(
        self,
        prompt: str,
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        params = {**self._invocation_params, **kwargs, "stream": True}
        self.get_sub_prompts(params, [prompt], stop)  # this mutates params
        for stream_resp in self.client.create(prompt=prompt, **params):
            if not isinstance(stream_resp, dict):
                stream_resp = stream_resp.model_dump()
            chunk = _stream_response_to_generation_chunk(stream_resp)
            if run_manager:
                run_manager.on_llm_new_token(
                    chunk.text,
                    chunk=chunk,
                    verbose=self.verbose,
                    logprobs=(
                        chunk.generation_info["logprobs"]
                        if chunk.generation_info
                        else None
                    ),
                )
            yield chunk

    async def _astream(
        self,
        prompt: str,
        stop: list[str] | None = None,
        run_manager: AsyncCallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> AsyncIterator[GenerationChunk]:
        params = {**self._invocation_params, **kwargs, "stream": True}
        self.get_sub_prompts(params, [prompt], stop)  # this mutates params
        async for stream_resp in await self.async_client.create(
            prompt=prompt, **params
        ):
            if not isinstance(stream_resp, dict):
                stream_resp = stream_resp.model_dump()
            chunk = _stream_response_to_generation_chunk(stream_resp)
            if run_manager:
                await run_manager.on_llm_new_token(
                    chunk.text,
                    chunk=chunk,
                    verbose=self.verbose,
                    logprobs=(
                        chunk.generation_info["logprobs"]
                        if chunk.generation_info
                        else None
                    ),
                )
            yield chunk

    def _generate(
        self,
        prompts: list[str],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Call out to OpenAI's endpoint with k unique prompts.

        Args:
            prompts: The prompts to pass into the model.
            stop: Optional list of stop words to use when generating.
            run_manager: Optional callback manager to use for the call.

        Returns:
            The full LLM output.

        Example:
            ```python
            response = openai.generate(["Tell me a joke."])
            ```
        """
        # TODO: write a unit test for this
        params = self._invocation_params
        params = {**params, **kwargs}
        sub_prompts = self.get_sub_prompts(params, prompts, stop)
        choices = []
        token_usage: dict[str, int] = {}
        # Get the token usage from the response.
        # Includes prompt, completion, and total tokens used.
        _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
        system_fingerprint: str | None = None
        for _prompts in sub_prompts:
            if self.streaming:
                if len(_prompts) > 1:
                    msg = "Cannot stream results with multiple prompts."
                    raise ValueError(msg)
                generation: GenerationChunk | None = None
                for chunk in self._stream(_prompts[0], stop, run_manager, **kwargs):
                    if generation is None:
                        generation = chunk
                    else:
                        generation += chunk
                if generation is None:
                    msg = "Generation is empty after streaming."
                    raise ValueError(msg)
                choices.append(
                    {
                        "text": generation.text,
                        "finish_reason": (
                            generation.generation_info.get("finish_reason")
                            if generation.generation_info
                            else None
                        ),
                        "logprobs": (
                            generation.generation_info.get("logprobs")
                            if generation.generation_info
                            else None
                        ),
                    }
                )
            else:
                response = self.client.create(prompt=_prompts, **params)
                if not isinstance(response, dict):
                    # The v1 client returns the response as a Pydantic object instead
                    # of a dict. For the transition period, we convert it to a dict.
                    response = response.model_dump()
                # If the API call returned an error, raise it here. Otherwise
                # `response["choices"]` is None and `choices.extend(...)` below would
                # raise "TypeError: 'NoneType' object is not iterable", masking the
                # true error.
                if response.get("error"):
                    raise ValueError(response.get("error"))
                choices.extend(response["choices"])
                _update_token_usage(_keys, response, token_usage)
                if not system_fingerprint:
                    system_fingerprint = response.get("system_fingerprint")
        return self.create_llm_result(
            choices, prompts, params, token_usage, system_fingerprint=system_fingerprint
        )

    async def _agenerate(
        self,
        prompts: list[str],
        stop: list[str] | None = None,
        run_manager: AsyncCallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Call out to OpenAI's endpoint async with k unique prompts."""
        params = self._invocation_params
        params = {**params, **kwargs}
        sub_prompts = self.get_sub_prompts(params, prompts, stop)
        choices = []
        token_usage: dict[str, int] = {}
        # Get the token usage from the response.
        # Includes prompt, completion, and total tokens used.
        _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
        system_fingerprint: str | None = None
        for _prompts in sub_prompts:
            if self.streaming:
                if len(_prompts) > 1:
                    msg = "Cannot stream results with multiple prompts."
                    raise ValueError(msg)
                generation: GenerationChunk | None = None
                async for chunk in self._astream(
                    _prompts[0], stop, run_manager, **kwargs
                ):
                    if generation is None:
                        generation = chunk
                    else:
                        generation += chunk
                if generation is None:
                    msg = "Generation is empty after streaming."
                    raise ValueError(msg)
                choices.append(
                    {
                        "text": generation.text,
                        "finish_reason": (
                            generation.generation_info.get("finish_reason")
                            if generation.generation_info
                            else None
                        ),
                        "logprobs": (
                            generation.generation_info.get("logprobs")
                            if generation.generation_info
                            else None
                        ),
                    }
                )
            else:
                response = await self.async_client.create(prompt=_prompts, **params)
                if not isinstance(response, dict):
                    response = response.model_dump()
                choices.extend(response["choices"])
                _update_token_usage(_keys, response, token_usage)
        return self.create_llm_result(
            choices, prompts, params, token_usage, system_fingerprint=system_fingerprint
        )

    def get_sub_prompts(
        self,
        params: dict[str, Any],
        prompts: list[str],
        stop: list[str] | None = None,
    ) -> list[list[str]]:
        """Get the sub prompts for llm call."""
        if stop is not None:
            params["stop"] = stop
        if params["max_tokens"] == -1:
            if len(prompts) != 1:
                msg = "max_tokens set to -1 not supported for multiple inputs."
                raise ValueError(msg)
            params["max_tokens"] = self.max_tokens_for_prompt(prompts[0])
        return [
            prompts[i : i + self.batch_size]
            for i in range(0, len(prompts), self.batch_size)
        ]
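
    # For example, with the default batch_size of 20, a list of 45 prompts is split
    # into batches of 20, 20 and 5, and each batch is sent to the API as one request.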

    def create_llm_result(
        self,
        choices: Any,
        prompts: list[str],
        params: dict[str, Any],
        token_usage: dict[str, int],
        *,
        system_fingerprint: str | None = None,
    ) -> LLMResult:
        """Create the LLMResult from the choices and prompts."""
        generations = []
        n = params.get("n", self.n)
        for i, _ in enumerate(prompts):
            sub_choices = choices[i * n : (i + 1) * n]
            generations.append(
                [
                    Generation(
                        text=choice["text"],
                        generation_info={
                            "finish_reason": choice.get("finish_reason"),
                            "logprobs": choice.get("logprobs"),
                        },
                    )
                    for choice in sub_choices
                ]
            )
        llm_output = {"token_usage": token_usage, "model_name": self.model_name}
        if system_fingerprint:
            llm_output["system_fingerprint"] = system_fingerprint
        return LLMResult(generations=generations, llm_output=llm_output)

    @property
    def _invocation_params(self) -> dict[str, Any]:
        """Get the parameters used to invoke the model."""
        return self._default_params

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"model_name": self.model_name, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "openai"

    def get_token_ids(self, text: str) -> list[int]:
        """Get the token IDs using the tiktoken package."""
        if self.custom_get_token_ids is not None:
            return self.custom_get_token_ids(text)
        # tiktoken is NOT supported for Python < 3.8; fall back to the base tokenizer.
        if sys.version_info < (3, 8):
            return super().get_token_ids(text)
        model_name = self.tiktoken_model_name or self.model_name
        try:
            enc = tiktoken.encoding_for_model(model_name)
        except KeyError:
            enc = tiktoken.get_encoding("cl100k_base")
        return enc.encode(
            text,
            allowed_special=self.allowed_special,
            disallowed_special=self.disallowed_special,
        )
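
    # For example, token counting for context-window checks goes through this path:
    #     llm = OpenAI(model="gpt-3.5-turbo-instruct")
    #     ids = llm.get_token_ids("Tell me a joke.")   # tiktoken token IDs
    #     llm.get_num_tokens("Tell me a joke.")        # == len(ids)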

    @staticmethod
    def modelname_to_contextsize(modelname: str) -> int:
        """Calculate the maximum number of tokens possible to generate for a model.

        Args:
            modelname: The modelname we want to know the context size for.

        Returns:
            The maximum context size.

        Example:
            ```python
            max_tokens = openai.modelname_to_contextsize("gpt-3.5-turbo-instruct")
            ```
        """
        model_token_mapping = {
            "gpt-4o-mini": 128_000,
            "gpt-4o": 128_000,
            "gpt-4o-2024-05-13": 128_000,
            "gpt-4": 8192,
            "gpt-4-0314": 8192,
            "gpt-4-0613": 8192,
            "gpt-4-32k": 32768,
            "gpt-4-32k-0314": 32768,
            "gpt-4-32k-0613": 32768,
            "gpt-3.5-turbo": 4096,
            "gpt-3.5-turbo-0301": 4096,
            "gpt-3.5-turbo-0613": 4096,
            "gpt-3.5-turbo-16k": 16385,
            "gpt-3.5-turbo-16k-0613": 16385,
            "gpt-3.5-turbo-instruct": 4096,
            "text-ada-001": 2049,
            "ada": 2049,
            "text-babbage-001": 2040,
            "babbage": 2049,
            "text-curie-001": 2049,
            "curie": 2049,
            "davinci": 2049,
            "text-davinci-003": 4097,
            "text-davinci-002": 4097,
            "code-davinci-002": 8001,
            "code-davinci-001": 8001,
            "code-cushman-002": 2048,
            "code-cushman-001": 2048,
        }
        # handling finetuned models
        if "ft-" in modelname:
            modelname = modelname.split(":")[0]
        context_size = model_token_mapping.get(modelname)
        if context_size is None:
            raise ValueError(
                f"Unknown model: {modelname}. Please provide a valid OpenAI model name. "
                "Known models are: " + ", ".join(model_token_mapping.keys())
            )
        return context_size

    @property
    def max_context_size(self) -> int:
        """Get max context size for this model."""
        return self.modelname_to_contextsize(self.model_name)

    def max_tokens_for_prompt(self, prompt: str) -> int:
        """Calculate the maximum number of tokens possible to generate for a prompt.

        Args:
            prompt: The prompt to pass into the model.

        Returns:
            The maximum number of tokens to generate for a prompt.

        Example:
            ```python
            max_tokens = openai.max_tokens_for_prompt("Tell me a joke.")
            ```
        """
        num_tokens = self.get_num_tokens(prompt)
        return self.max_context_size - num_tokens
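
    # For example, with "gpt-3.5-turbo-instruct" (4096-token context) and a prompt
    # that tokenizes to 96 tokens, max_tokens_for_prompt returns 4096 - 96 = 4000,
    # which is the value get_sub_prompts substitutes when max_tokens is set to -1.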


class OpenAI(BaseOpenAI):
    """OpenAI completion model integration.

    Setup:
        Install `langchain-openai` and set environment variable `OPENAI_API_KEY`.

        ```bash
        pip install -U langchain-openai
        export OPENAI_API_KEY="your-api-key"
        ```

    Key init args — completion params:
        model:
            Name of OpenAI model to use.
        temperature:
            Sampling temperature.
        max_tokens:
            Max number of tokens to generate.
        logprobs:
            Whether to return logprobs.
        stream_options:
            Configure streaming outputs, like whether to return token usage when
            streaming (`{"include_usage": True}`).

    Key init args — client params:
        timeout:
            Timeout for requests.
        max_retries:
            Max number of retries.
        api_key:
            OpenAI API key. If not passed in, will be read from env var `OPENAI_API_KEY`.
        base_url:
            Base URL for API requests. Only specify if using a proxy or service
            emulator.
        organization:
            OpenAI organization ID. If not passed in, will be read from env
            var `OPENAI_ORG_ID`.

    See full list of supported init args and their descriptions in the params section.

    Instantiate:
        ```python
        from langchain_openai import OpenAI

        model = OpenAI(
            model="gpt-3.5-turbo-instruct",
            temperature=0,
            max_retries=2,
            # api_key="...",
            # base_url="...",
            # organization="...",
            # other params...
        )
        ```

    Invoke:
        ```python
        input_text = "The meaning of life is "
        model.invoke(input_text)
        ```

        ```txt
        "a philosophical question that has been debated by thinkers and scholars for centuries."
        ```

    Stream:
        ```python
        for chunk in model.stream(input_text):
            print(chunk, end="|")
        ```

        ```txt
        a| philosophical| question| that| has| been| debated| by| thinkers| and| scholars| for| centuries|.
        ```

        ```python
        "".join(model.stream(input_text))
        ```

        ```txt
        "a philosophical question that has been debated by thinkers and scholars for centuries."
        ```

    Async:
        ```python
        await model.ainvoke(input_text)

        # stream:
        # async for chunk in model.astream(input_text):
        #     print(chunk)

        # batch:
        # await model.abatch([input_text])
        ```

        ```txt
        "a philosophical question that has been debated by thinkers and scholars for centuries."
        ```

    """  # noqa: E501

    @classmethod
    def get_lc_namespace(cls) -> list[str]:
        """Get the namespace of the LangChain object.

        Returns:
            `["langchain", "llms", "openai"]`
        """
        return ["langchain", "llms", "openai"]

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this model can be serialized by LangChain."""
        return True

    @property
    def _invocation_params(self) -> dict[str, Any]:
        return {"model": self.model_name, **super()._invocation_params}

    @property
    def lc_secrets(self) -> dict[str, str]:
        """Mapping of secret keys to environment variables."""
        return {"openai_api_key": "OPENAI_API_KEY"}

    @property
    def lc_attributes(self) -> dict[str, Any]:
        """LangChain attributes for this class."""
        attributes: dict[str, Any] = {}
        if self.openai_api_base:
            attributes["openai_api_base"] = self.openai_api_base
        if self.openai_organization:
            attributes["openai_organization"] = self.openai_organization
        if self.openai_proxy:
            attributes["openai_proxy"] = self.openai_proxy
        return attributes
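
    # Sketch of what the serialization hooks above enable (assumes langchain_core's
    # `load` helpers): `lc_secrets` keeps the API key out of the serialized payload,
    # and `lc_attributes` records non-default client settings.
    #     from langchain_core.load import dumps
    #     model = OpenAI(model="gpt-3.5-turbo-instruct")
    #     serialized = dumps(model)  # api key stored as a secret reference, not a value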