  1. """Base classes for OpenAI embeddings."""
  2. from __future__ import annotations
  3. import logging
  4. import warnings
  5. from collections.abc import Awaitable, Callable, Iterable, Mapping, Sequence
  6. from typing import Any, Literal, cast
  7. import openai
  8. import tiktoken
  9. from langchain_core.embeddings import Embeddings
  10. from langchain_core.runnables.config import run_in_executor
  11. from langchain_core.utils import from_env, get_pydantic_field_names, secret_from_env
  12. from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
  13. from typing_extensions import Self
  14. from langchain_openai.chat_models._client_utils import _resolve_sync_and_async_api_keys
  15. logger = logging.getLogger(__name__)
  16. MAX_TOKENS_PER_REQUEST = 300000
  17. """API limit per request for embedding tokens."""

def _process_batched_chunked_embeddings(
    num_texts: int,
    tokens: list[list[int] | str],
    batched_embeddings: list[list[float]],
    indices: list[int],
    skip_empty: bool,
) -> list[list[float] | None]:
    # for each text, this is the list of embeddings (list of list of floats)
    # corresponding to the chunks of the text
    results: list[list[list[float]]] = [[] for _ in range(num_texts)]

    # for each text, this is the token length of each chunk
    # for transformers tokenization, this is the string length
    # for tiktoken, this is the number of tokens
    num_tokens_in_batch: list[list[int]] = [[] for _ in range(num_texts)]

    for i in range(len(indices)):
        if skip_empty and len(batched_embeddings[i]) == 1:
            continue
        results[indices[i]].append(batched_embeddings[i])
        num_tokens_in_batch[indices[i]].append(len(tokens[i]))

    # for each text, this is the final embedding
    embeddings: list[list[float] | None] = []
    for i in range(num_texts):
        # an embedding for each chunk
        _result: list[list[float]] = results[i]

        if len(_result) == 0:
            # this will be populated with the embedding of an empty string
            # in the sync or async code calling this
            embeddings.append(None)
            continue

        if len(_result) == 1:
            # if only one embedding was produced, use it
            embeddings.append(_result[0])
            continue

        # otherwise, take a token-weighted average of the chunk embeddings
        # should be the same as
        # average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
        total_weight = sum(num_tokens_in_batch[i])
        average = [
            sum(
                val * weight
                for val, weight in zip(embedding, num_tokens_in_batch[i], strict=False)
            )
            / total_weight
            for embedding in zip(*_result, strict=False)
        ]

        # should be the same as
        # embeddings.append((average / np.linalg.norm(average)).tolist())
        magnitude = sum(val**2 for val in average) ** 0.5
        embeddings.append([val / magnitude for val in average])

    return embeddings
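
# Illustrative sketch (not part of the module): how the token-weighted average
# and L2 normalization above combine two chunk embeddings. The numbers are
# made up for the example.
#
#   chunks  = [[1.0, 0.0], [0.0, 1.0]]  # embeddings of two chunks of one text
#   weights = [3, 1]                    # token counts of those chunks
#   average = [(1.0 * 3 + 0.0 * 1) / 4, (0.0 * 3 + 1.0 * 1) / 4]  # [0.75, 0.25]
#   norm    = (0.75**2 + 0.25**2) ** 0.5                          # ~0.7906
#   result  = [0.75 / norm, 0.25 / norm]                          # ~[0.9487, 0.3162]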

class OpenAIEmbeddings(BaseModel, Embeddings):
    """OpenAI embedding model integration.

    Setup:
        Install `langchain_openai` and set environment variable `OPENAI_API_KEY`.

        ```bash
        pip install -U langchain_openai
        export OPENAI_API_KEY="your-api-key"
        ```

    Key init args — embedding params:
        model:
            Name of OpenAI model to use.
        dimensions:
            The number of dimensions the resulting output embeddings should have.
            Only supported in `text-embedding-3` and later models.

    Key init args — client params:
        api_key:
            OpenAI API key. If not passed in will be read from env var
            `OPENAI_API_KEY`.
        organization:
            OpenAI organization ID. If not passed in will be read
            from env var `OPENAI_ORG_ID`.
        max_retries:
            Maximum number of retries to make when generating.
        request_timeout:
            Timeout for requests to the OpenAI API.

    See full list of supported init args and their descriptions in the params
    section.

    Instantiate:
        ```python
        from langchain_openai import OpenAIEmbeddings

        embed = OpenAIEmbeddings(
            model="text-embedding-3-large",
            # With the `text-embedding-3` class of models, you can specify
            # the size of the embeddings you want returned.
            # dimensions=1024
        )
        ```

    Embed single text:
        ```python
        input_text = "The meaning of life is 42"
        vector = embed.embed_query(input_text)
        print(vector[:3])
        ```

        ```python
        [-0.024603435769677162, -0.007543657906353474, 0.0039630369283258915]
        ```

    Embed multiple texts:
        ```python
        vectors = embed.embed_documents(["hello", "goodbye"])
        # Showing only the first 3 coordinates
        print(len(vectors))
        print(vectors[0][:3])
        ```

        ```python
        2
        [-0.024603435769677162, -0.007543657906353474, 0.0039630369283258915]
        ```

    Async:
        ```python
        vector = await embed.aembed_query(input_text)
        print(vector[:3])

        # multiple:
        # vectors = await embed.aembed_documents(input_texts)
        ```

        ```python
        [-0.009100092574954033, 0.005071679595857859, -0.0029193938244134188]
        ```
    """

    client: Any = Field(default=None, exclude=True)
    async_client: Any = Field(default=None, exclude=True)
    model: str = "text-embedding-ada-002"
    dimensions: int | None = None
    """The number of dimensions the resulting output embeddings should have.
    Only supported in `text-embedding-3` and later models.
    """
    # to support Azure OpenAI Service custom deployment names
    deployment: str | None = model
    # TODO: Move to AzureOpenAIEmbeddings.
    openai_api_version: str | None = Field(
        default_factory=from_env("OPENAI_API_VERSION", default=None),
        alias="api_version",
    )
    """Automatically inferred from env var `OPENAI_API_VERSION` if not provided."""
    # to support Azure OpenAI Service custom endpoints
    openai_api_base: str | None = Field(
        alias="base_url", default_factory=from_env("OPENAI_API_BASE", default=None)
    )
    """Base URL path for API requests, leave blank if not using a proxy or service
    emulator."""
    # to support Azure OpenAI Service custom endpoints
    openai_api_type: str | None = Field(
        default_factory=from_env("OPENAI_API_TYPE", default=None)
    )
    # to support explicit proxy for OpenAI
    openai_proxy: str | None = Field(
        default_factory=from_env("OPENAI_PROXY", default=None)
    )
    embedding_ctx_length: int = 8191
    """The maximum number of tokens to embed at once."""
    openai_api_key: (
        SecretStr | None | Callable[[], str] | Callable[[], Awaitable[str]]
    ) = Field(
        alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None)
    )
    """Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
    openai_organization: str | None = Field(
        alias="organization",
        default_factory=from_env(
            ["OPENAI_ORG_ID", "OPENAI_ORGANIZATION"], default=None
        ),
    )
    """Automatically inferred from env var `OPENAI_ORG_ID` if not provided."""
    allowed_special: Literal["all"] | set[str] | None = None
    disallowed_special: Literal["all"] | set[str] | Sequence[str] | None = None
    chunk_size: int = 1000
    """Maximum number of texts to embed in each batch."""
    max_retries: int = 2
    """Maximum number of retries to make when generating."""
    request_timeout: float | tuple[float, float] | Any | None = Field(
        default=None, alias="timeout"
    )
    """Timeout for requests to the OpenAI API. Can be float, `httpx.Timeout` or
    None."""
    headers: Any = None
    tiktoken_enabled: bool = True
    """Set this to False for non-OpenAI implementations of the embeddings API, e.g.
    the `--extensions openai` extension for `text-generation-webui`."""
    tiktoken_model_name: str | None = None
    """The model name to pass to tiktoken when using this class.

    Tiktoken is used to count the number of tokens in documents to constrain
    them to be under a certain limit.

    By default, when set to `None`, this will be the same as the embedding model name.
    However, there are some cases where you may want to use this `Embedding` class
    with a model name not supported by tiktoken. This can include when using Azure
    embeddings or when using one of the many model providers that expose an
    OpenAI-like API but with different models. In those cases, in order to avoid
    erroring when tiktoken is called, you can specify a model name to use here.
    """
    show_progress_bar: bool = False
    """Whether to show a progress bar when embedding."""
    model_kwargs: dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `create` call not explicitly specified."""
    skip_empty: bool = False
    """Whether to skip empty strings when embedding or raise an error."""
    default_headers: Mapping[str, str] | None = None
    default_query: Mapping[str, object] | None = None
    # Configure a custom httpx client. See the
    # [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
    retry_min_seconds: int = 4
    """Min number of seconds to wait between retries."""
    retry_max_seconds: int = 20
    """Max number of seconds to wait between retries."""
    http_client: Any | None = None
    """Optional `httpx.Client`.
    Only used for sync invocations. Must specify `http_async_client` as well if you'd
    like a custom client for async invocations.
    """
    http_async_client: Any | None = None
    """Optional `httpx.AsyncClient`.
    Only used for async invocations. Must specify `http_client` as well if you'd like a
    custom client for sync invocations.
    """
    check_embedding_ctx_length: bool = True
    """Whether to check the token length of inputs and automatically split inputs
    longer than `embedding_ctx_length`."""

    model_config = ConfigDict(
        extra="forbid", populate_by_name=True, protected_namespaces=()
    )

    @model_validator(mode="before")
    @classmethod
    def build_extra(cls, values: dict[str, Any]) -> Any:
        """Build extra kwargs from additional params that were passed in."""
        all_required_field_names = get_pydantic_field_names(cls)
        extra = values.get("model_kwargs", {})
        for field_name in list(values):
            if field_name in extra:
                msg = f"Found {field_name} supplied twice."
                raise ValueError(msg)
            if field_name not in all_required_field_names:
                warnings.warn(
                    f"WARNING! {field_name} is not a default parameter. "
                    f"{field_name} was transferred to model_kwargs. "
                    f"Please confirm that {field_name} is what you intended."
                )
                extra[field_name] = values.pop(field_name)
        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
        if invalid_model_kwargs:
            msg = (
                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
                f"Instead they were passed in as part of `model_kwargs` parameter."
            )
            raise ValueError(msg)
        values["model_kwargs"] = extra
        return values
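
    # Illustrative sketch (not part of the class): what `build_extra` does with
    # an unrecognized keyword. `encoding_format` is shown only as an example of
    # a parameter the `create` call accepts but the class does not declare.
    #
    #   embed = OpenAIEmbeddings(model="text-embedding-3-small",
    #                            encoding_format="float")
    #   # -> warns, and is equivalent to:
    #   embed = OpenAIEmbeddings(model="text-embedding-3-small",
    #                            model_kwargs={"encoding_format": "float"})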

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Validate that the API key and python package exist in the environment."""
        if self.openai_api_type in ("azure", "azure_ad", "azuread"):
            msg = (
                "If you are using Azure, please use the `AzureOpenAIEmbeddings` class."
            )
            raise ValueError(msg)
        # Resolve API key from SecretStr or Callable
        sync_api_key_value: str | Callable[[], str] | None = None
        async_api_key_value: str | Callable[[], Awaitable[str]] | None = None
        if self.openai_api_key is not None:
            # Because OpenAI and AsyncOpenAI clients support either sync or async
            # callables for the API key, we need to resolve separate values here.
            sync_api_key_value, async_api_key_value = _resolve_sync_and_async_api_keys(
                self.openai_api_key
            )
        client_params: dict = {
            "organization": self.openai_organization,
            "base_url": self.openai_api_base,
            "timeout": self.request_timeout,
            "max_retries": self.max_retries,
            "default_headers": self.default_headers,
            "default_query": self.default_query,
        }
        if self.openai_proxy and (self.http_client or self.http_async_client):
            openai_proxy = self.openai_proxy
            http_client = self.http_client
            http_async_client = self.http_async_client
            msg = (
                "Cannot specify 'openai_proxy' if one of "
                "'http_client'/'http_async_client' is already specified. Received:\n"
                f"{openai_proxy=}\n{http_client=}\n{http_async_client=}"
            )
            raise ValueError(msg)
        if not self.client:
            if sync_api_key_value is None:
                # No valid sync API key; leave client as None and raise an
                # informative error on invocation.
                self.client = None
            else:
                if self.openai_proxy and not self.http_client:
                    try:
                        import httpx
                    except ImportError as e:
                        msg = (
                            "Could not import httpx python package. "
                            "Please install it with `pip install httpx`."
                        )
                        raise ImportError(msg) from e
                    self.http_client = httpx.Client(proxy=self.openai_proxy)
                sync_specific = {
                    "http_client": self.http_client,
                    "api_key": sync_api_key_value,
                }
                self.client = openai.OpenAI(**client_params, **sync_specific).embeddings  # type: ignore[arg-type]
        if not self.async_client:
            if self.openai_proxy and not self.http_async_client:
                try:
                    import httpx
                except ImportError as e:
                    msg = (
                        "Could not import httpx python package. "
                        "Please install it with `pip install httpx`."
                    )
                    raise ImportError(msg) from e
                self.http_async_client = httpx.AsyncClient(proxy=self.openai_proxy)
            async_specific = {
                "http_client": self.http_async_client,
                "api_key": async_api_key_value,
            }
            self.async_client = openai.AsyncOpenAI(
                **client_params,
                **async_specific,  # type: ignore[arg-type]
            ).embeddings
        return self
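
    # Illustrative sketch (not part of the class): the API key may be a string,
    # a sync callable, or an async callable. With an async-only callable the
    # sync client stays None and only the async methods work.
    #
    #   def get_key() -> str:            # sync callable: both clients are built
    #       return os.environ["OPENAI_API_KEY"]
    #
    #   async def aget_key() -> str:     # async callable: async client only
    #       return await fetch_key_from_vault()  # hypothetical helper
    #
    #   OpenAIEmbeddings(api_key=get_key)
    #   OpenAIEmbeddings(api_key=aget_key)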

    @property
    def _invocation_params(self) -> dict[str, Any]:
        params: dict = {"model": self.model, **self.model_kwargs}
        if self.dimensions is not None:
            params["dimensions"] = self.dimensions
        return params

    def _ensure_sync_client_available(self) -> None:
        """Check that the sync client is available; raise an error if not."""
        if self.client is None:
            msg = (
                "Sync client is not available. This happens when an async callable "
                "was provided for the API key. Use async methods (`aembed_query`, "
                "`aembed_documents`) instead, or provide a string or sync callable "
                "for the API key."
            )
            raise ValueError(msg)

    def _tokenize(
        self, texts: list[str], chunk_size: int
    ) -> tuple[Iterable[int], list[list[int] | str], list[int], list[int]]:
        """Tokenize and batch input texts.

        Splits texts based on `embedding_ctx_length` and groups them into batches
        of size `chunk_size`.

        Args:
            texts: The list of texts to tokenize.
            chunk_size: The maximum number of texts to include in a single batch.

        Returns:
            A tuple containing:

            1. An iterable of starting indices in the token list for each batch.
            2. A list of tokenized texts (token arrays for tiktoken, strings for
               HuggingFace).
            3. A list mapping each token array to the index of the original
               text. Same length as the token list.
            4. A list of token counts for each tokenized text.
        """
        tokens: list[list[int] | str] = []
        indices: list[int] = []
        token_counts: list[int] = []
        model_name = self.tiktoken_model_name or self.model

        # If tiktoken flag set to False
        if not self.tiktoken_enabled:
            try:
                from transformers import AutoTokenizer
            except ImportError as e:
                msg = (
                    "Could not import transformers python package. "
                    "This is needed for OpenAIEmbeddings to work without "
                    "`tiktoken`. Please install it with `pip install transformers`."
                )
                raise ImportError(msg) from e
            tokenizer = AutoTokenizer.from_pretrained(
                pretrained_model_name_or_path=model_name
            )
            for i, text in enumerate(texts):
                # Tokenize the text using HuggingFace transformers
                tokenized: list[int] = tokenizer.encode(text, add_special_tokens=False)
                # Split tokens into chunks respecting the embedding_ctx_length
                for j in range(0, len(tokenized), self.embedding_ctx_length):
                    token_chunk: list[int] = tokenized[
                        j : j + self.embedding_ctx_length
                    ]
                    # Convert token IDs back to a string
                    chunk_text: str = tokenizer.decode(token_chunk)
                    tokens.append(chunk_text)
                    indices.append(i)
                    token_counts.append(len(token_chunk))
        else:
            try:
                encoding = tiktoken.encoding_for_model(model_name)
            except KeyError:
                encoding = tiktoken.get_encoding("cl100k_base")
            encoder_kwargs: dict[str, Any] = {
                k: v
                for k, v in {
                    "allowed_special": self.allowed_special,
                    "disallowed_special": self.disallowed_special,
                }.items()
                if v is not None
            }
            for i, text in enumerate(texts):
                if self.model.endswith("001"):
                    # See: https://github.com/openai/openai-python/
                    # issues/418#issuecomment-1525939500
                    # replace newlines, which can negatively affect performance.
                    text = text.replace("\n", " ")
                if encoder_kwargs:
                    token = encoding.encode(text, **encoder_kwargs)
                else:
                    token = encoding.encode_ordinary(text)
                # Split tokens into chunks respecting the embedding_ctx_length
                for j in range(0, len(token), self.embedding_ctx_length):
                    token_chunk = token[j : j + self.embedding_ctx_length]
                    tokens.append(token_chunk)
                    indices.append(i)
                    token_counts.append(len(token_chunk))

        if self.show_progress_bar:
            try:
                from tqdm.auto import tqdm

                _iter: Iterable = tqdm(range(0, len(tokens), chunk_size))
            except ImportError:
                _iter = range(0, len(tokens), chunk_size)
        else:
            _iter = range(0, len(tokens), chunk_size)
        return _iter, tokens, indices, token_counts
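
    # Illustrative sketch (not part of the class): with the default
    # embedding_ctx_length of 8191, a 20,000-token text is split into three
    # chunks of 8191, 8191 and 3618 tokens, all pointing back at the same text:
    #
    #   tokens       -> [chunk0, chunk1, chunk2]   # token arrays
    #   indices      -> [0, 0, 0]                  # all map to text 0
    #   token_counts -> [8191, 8191, 3618]
    #
    # The per-chunk embeddings are later merged back into one vector by
    # _process_batched_chunked_embeddings above.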

    # please refer to
    # https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
    def _get_len_safe_embeddings(
        self,
        texts: list[str],
        *,
        engine: str,
        chunk_size: int | None = None,
        **kwargs: Any,
    ) -> list[list[float]]:
        """Generate length-safe embeddings for a list of texts.

        This method handles tokenization and embedding generation, respecting the
        `embedding_ctx_length` and `chunk_size`. Supports both `tiktoken` and
        HuggingFace `transformers` based on the `tiktoken_enabled` flag.

        Args:
            texts: The list of texts to embed.
            engine: The engine or model to use for embeddings.
            chunk_size: The size of chunks for processing embeddings.

        Returns:
            A list of embeddings, one for each input text.
        """
        _chunk_size = chunk_size or self.chunk_size
        client_kwargs = {**self._invocation_params, **kwargs}
        _iter, tokens, indices, token_counts = self._tokenize(texts, _chunk_size)
        batched_embeddings: list[list[float]] = []

        # Process in batches respecting the token limit
        i = 0
        while i < len(tokens):
            # Determine how many chunks we can include in this batch
            batch_token_count = 0
            batch_end = i
            for j in range(i, min(i + _chunk_size, len(tokens))):
                chunk_tokens = token_counts[j]
                # Check if adding this chunk would exceed the limit
                if batch_token_count + chunk_tokens > MAX_TOKENS_PER_REQUEST:
                    if batch_end == i:
                        # Single chunk exceeds limit - handle it anyway
                        batch_end = j + 1
                    break
                batch_token_count += chunk_tokens
                batch_end = j + 1

            # Make API call with this batch
            batch_tokens = tokens[i:batch_end]
            response = self.client.create(input=batch_tokens, **client_kwargs)
            if not isinstance(response, dict):
                response = response.model_dump()
            batched_embeddings.extend(r["embedding"] for r in response["data"])
            i = batch_end

        embeddings = _process_batched_chunked_embeddings(
            len(texts), tokens, batched_embeddings, indices, self.skip_empty
        )
        _cached_empty_embedding: list[float] | None = None

        def empty_embedding() -> list[float]:
            nonlocal _cached_empty_embedding
            if _cached_empty_embedding is None:
                average_embedded = self.client.create(input="", **client_kwargs)
                if not isinstance(average_embedded, dict):
                    average_embedded = average_embedded.model_dump()
                _cached_empty_embedding = average_embedded["data"][0]["embedding"]
            return _cached_empty_embedding

        return [e if e is not None else empty_embedding() for e in embeddings]
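
    # Illustrative sketch (not part of the class): how the two limits interact.
    # With the defaults (chunk_size=1000 texts per batch,
    # embedding_ctx_length=8191, MAX_TOKENS_PER_REQUEST=300000), a batch of
    # maximally long chunks is capped by the token limit first:
    # floor(300000 / 8191) = 36 chunks per request. Batches of short texts,
    # by contrast, are capped by chunk_size at 1000 per request.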

    # please refer to
    # https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
    async def _aget_len_safe_embeddings(
        self,
        texts: list[str],
        *,
        engine: str,
        chunk_size: int | None = None,
        **kwargs: Any,
    ) -> list[list[float]]:
        """Asynchronously generate length-safe embeddings for a list of texts.

        This method handles tokenization and embedding generation, respecting the
        `embedding_ctx_length` and `chunk_size`. Supports both `tiktoken` and
        HuggingFace `transformers` based on the `tiktoken_enabled` flag.

        Args:
            texts: The list of texts to embed.
            engine: The engine or model to use for embeddings.
            chunk_size: The size of chunks for processing embeddings.

        Returns:
            A list of embeddings, one for each input text.
        """
        _chunk_size = chunk_size or self.chunk_size
        client_kwargs = {**self._invocation_params, **kwargs}
        _iter, tokens, indices, token_counts = await run_in_executor(
            None, self._tokenize, texts, _chunk_size
        )
        batched_embeddings: list[list[float]] = []

        # Process in batches respecting the token limit
        i = 0
        while i < len(tokens):
            # Determine how many chunks we can include in this batch
            batch_token_count = 0
            batch_end = i
            for j in range(i, min(i + _chunk_size, len(tokens))):
                chunk_tokens = token_counts[j]
                # Check if adding this chunk would exceed the limit
                if batch_token_count + chunk_tokens > MAX_TOKENS_PER_REQUEST:
                    if batch_end == i:
                        # Single chunk exceeds limit - handle it anyway
                        batch_end = j + 1
                    break
                batch_token_count += chunk_tokens
                batch_end = j + 1

            # Make API call with this batch
            batch_tokens = tokens[i:batch_end]
            response = await self.async_client.create(
                input=batch_tokens, **client_kwargs
            )
            if not isinstance(response, dict):
                response = response.model_dump()
            batched_embeddings.extend(r["embedding"] for r in response["data"])
            i = batch_end

        embeddings = _process_batched_chunked_embeddings(
            len(texts), tokens, batched_embeddings, indices, self.skip_empty
        )
        _cached_empty_embedding: list[float] | None = None

        async def empty_embedding() -> list[float]:
            nonlocal _cached_empty_embedding
            if _cached_empty_embedding is None:
                average_embedded = await self.async_client.create(
                    input="", **client_kwargs
                )
                if not isinstance(average_embedded, dict):
                    average_embedded = average_embedded.model_dump()
                _cached_empty_embedding = average_embedded["data"][0]["embedding"]
            return _cached_empty_embedding

        return [e if e is not None else await empty_embedding() for e in embeddings]

    def embed_documents(
        self, texts: list[str], chunk_size: int | None = None, **kwargs: Any
    ) -> list[list[float]]:
        """Call OpenAI's embedding endpoint to embed search docs.

        Args:
            texts: The list of texts to embed.
            chunk_size: The chunk size of embeddings.
                If `None`, will use the chunk size specified by the class.
            kwargs: Additional keyword arguments to pass to the embedding API.

        Returns:
            List of embeddings, one for each text.
        """
        self._ensure_sync_client_available()
        chunk_size_ = chunk_size or self.chunk_size
        client_kwargs = {**self._invocation_params, **kwargs}
        if not self.check_embedding_ctx_length:
            embeddings: list[list[float]] = []
            for i in range(0, len(texts), chunk_size_):
                response = self.client.create(
                    input=texts[i : i + chunk_size_], **client_kwargs
                )
                if not isinstance(response, dict):
                    response = response.model_dump()
                embeddings.extend(r["embedding"] for r in response["data"])
            return embeddings

        # Unconditionally call _get_len_safe_embeddings to handle length safety.
        # This could be optimized to avoid double work when all texts are short enough.
        engine = cast(str, self.deployment)
        return self._get_len_safe_embeddings(
            texts, engine=engine, chunk_size=chunk_size, **kwargs
        )
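
    # Illustrative sketch (not part of the class): passing a per-call
    # chunk_size and an extra API kwarg. `user` is a standard OpenAI request
    # field, shown here only as an example of what **kwargs forwards to
    # `create`.
    #
    #   embed = OpenAIEmbeddings(model="text-embedding-3-small")
    #   vectors = embed.embed_documents(texts, chunk_size=100, user="demo-app")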

    async def aembed_documents(
        self, texts: list[str], chunk_size: int | None = None, **kwargs: Any
    ) -> list[list[float]]:
        """Asynchronously call OpenAI's embedding endpoint to embed search docs.

        Args:
            texts: The list of texts to embed.
            chunk_size: The chunk size of embeddings.
                If `None`, will use the chunk size specified by the class.
            kwargs: Additional keyword arguments to pass to the embedding API.

        Returns:
            List of embeddings, one for each text.
        """
        chunk_size_ = chunk_size or self.chunk_size
        client_kwargs = {**self._invocation_params, **kwargs}
        if not self.check_embedding_ctx_length:
            embeddings: list[list[float]] = []
            for i in range(0, len(texts), chunk_size_):
                response = await self.async_client.create(
                    input=texts[i : i + chunk_size_], **client_kwargs
                )
                if not isinstance(response, dict):
                    response = response.model_dump()
                embeddings.extend(r["embedding"] for r in response["data"])
            return embeddings

        # Unconditionally call _aget_len_safe_embeddings to handle length safety.
        # This could be optimized to avoid double work when all texts are short enough.
        engine = cast(str, self.deployment)
        return await self._aget_len_safe_embeddings(
            texts, engine=engine, chunk_size=chunk_size, **kwargs
        )

    def embed_query(self, text: str, **kwargs: Any) -> list[float]:
        """Call out to OpenAI's embedding endpoint for embedding query text.

        Args:
            text: The text to embed.
            kwargs: Additional keyword arguments to pass to the embedding API.

        Returns:
            Embedding for the text.
        """
        self._ensure_sync_client_available()
        return self.embed_documents([text], **kwargs)[0]

    async def aembed_query(self, text: str, **kwargs: Any) -> list[float]:
        """Call out to OpenAI's embedding endpoint async for embedding query text.

        Args:
            text: The text to embed.
            kwargs: Additional keyword arguments to pass to the embedding API.

        Returns:
            Embedding for the text.
        """
        embeddings = await self.aembed_documents([text], **kwargs)
        return embeddings[0]
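

# ---------------------------------------------------------------------------
# Illustrative usage (not part of the module): a minimal sync/async round
# trip. Assumes OPENAI_API_KEY is set in the environment;
# "text-embedding-3-small" is just an example model name.
#
#   import asyncio
#
#   embed = OpenAIEmbeddings(model="text-embedding-3-small")
#   print(embed.embed_query("hello")[:3])
#
#   async def main() -> None:
#       print((await embed.aembed_query("hello"))[:3])
#
#   asyncio.run(main())
# ---------------------------------------------------------------------------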