"""Factory functions for embeddings.""" import functools from importlib import util from typing import Any from langchain_core.embeddings import Embeddings _SUPPORTED_PROVIDERS = { "azure_openai": "langchain_openai", "bedrock": "langchain_aws", "cohere": "langchain_cohere", "google_vertexai": "langchain_google_vertexai", "huggingface": "langchain_huggingface", "mistralai": "langchain_mistralai", "ollama": "langchain_ollama", "openai": "langchain_openai", } def _get_provider_list() -> str: """Get formatted list of providers and their packages.""" return "\n".join(f" - {p}: {pkg.replace('_', '-')}" for p, pkg in _SUPPORTED_PROVIDERS.items()) def _parse_model_string(model_name: str) -> tuple[str, str]: """Parse a model string into provider and model name components. The model string should be in the format 'provider:model-name', where provider is one of the supported providers. Args: model_name: A model string in the format 'provider:model-name' Returns: A tuple of (provider, model_name) ```python _parse_model_string("openai:text-embedding-3-small") # Returns: ("openai", "text-embedding-3-small") _parse_model_string("bedrock:amazon.titan-embed-text-v1") # Returns: ("bedrock", "amazon.titan-embed-text-v1") ``` Raises: ValueError: If the model string is not in the correct format or the provider is unsupported """ if ":" not in model_name: providers = _SUPPORTED_PROVIDERS msg = ( f"Invalid model format '{model_name}'.\n" f"Model name must be in format 'provider:model-name'\n" f"Example valid model strings:\n" f" - openai:text-embedding-3-small\n" f" - bedrock:amazon.titan-embed-text-v1\n" f" - cohere:embed-english-v3.0\n" f"Supported providers: {providers}" ) raise ValueError(msg) provider, model = model_name.split(":", 1) provider = provider.lower().strip() model = model.strip() if provider not in _SUPPORTED_PROVIDERS: msg = ( f"Provider '{provider}' is not supported.\n" f"Supported providers and their required packages:\n" f"{_get_provider_list()}" ) raise ValueError(msg) if not model: msg = "Model name cannot be empty" raise ValueError(msg) return provider, model def _infer_model_and_provider( model: str, *, provider: str | None = None, ) -> tuple[str, str]: if not model.strip(): msg = "Model name cannot be empty" raise ValueError(msg) if provider is None and ":" in model: provider, model_name = _parse_model_string(model) else: model_name = model if not provider: providers = _SUPPORTED_PROVIDERS msg = ( "Must specify either:\n" "1. A model string in format 'provider:model-name'\n" " Example: 'openai:text-embedding-3-small'\n" "2. Or explicitly set provider from: " f"{providers}" ) raise ValueError(msg) if provider not in _SUPPORTED_PROVIDERS: msg = ( f"Provider '{provider}' is not supported.\n" f"Supported providers and their required packages:\n" f"{_get_provider_list()}" ) raise ValueError(msg) return provider, model_name @functools.lru_cache(maxsize=len(_SUPPORTED_PROVIDERS)) def _check_pkg(pkg: str) -> None: """Check if a package is installed.""" if not util.find_spec(pkg): msg = f"Could not import {pkg} python package. Please install it with `pip install {pkg}`" raise ImportError(msg) def init_embeddings( model: str, *, provider: str | None = None, **kwargs: Any, ) -> Embeddings: """Initialize an embedding model from a model name and optional provider. !!! note Requires the integration package for the chosen model provider to be installed. See the `model_provider` parameter below for specific package names (e.g., `pip install langchain-openai`). Refer to the [provider integration's API reference](https://docs.langchain.com/oss/python/integrations/providers) for supported model parameters to use as `**kwargs`. Args: model: The name of the model, e.g. `'openai:text-embedding-3-small'`. You can also specify model and model provider in a single argument using `'{model_provider}:{model}'` format, e.g. `'openai:text-embedding-3-small'`. provider: The model provider if not specified as part of the model arg (see above). Supported `provider` values and the corresponding integration package are: - `openai` -> [`langchain-openai`](https://docs.langchain.com/oss/python/integrations/providers/openai) - `azure_openai` -> [`langchain-openai`](https://docs.langchain.com/oss/python/integrations/providers/openai) - `bedrock` -> [`langchain-aws`](https://docs.langchain.com/oss/python/integrations/providers/aws) - `cohere` -> [`langchain-cohere`](https://docs.langchain.com/oss/python/integrations/providers/cohere) - `google_vertexai` -> [`langchain-google-vertexai`](https://docs.langchain.com/oss/python/integrations/providers/google) - `huggingface` -> [`langchain-huggingface`](https://docs.langchain.com/oss/python/integrations/providers/huggingface) - `mistraiai` -> [`langchain-mistralai`](https://docs.langchain.com/oss/python/integrations/providers/mistralai) - `ollama` -> [`langchain-ollama`](https://docs.langchain.com/oss/python/integrations/providers/ollama) **kwargs: Additional model-specific parameters passed to the embedding model. These vary by provider. Refer to the specific model provider's [integration reference](https://reference.langchain.com/python/integrations/) for all available parameters. Returns: An `Embeddings` instance that can generate embeddings for text. Raises: ValueError: If the model provider is not supported or cannot be determined ImportError: If the required provider package is not installed ???+ example ```python # pip install langchain langchain-openai # Using a model string model = init_embeddings("openai:text-embedding-3-small") model.embed_query("Hello, world!") # Using explicit provider model = init_embeddings(model="text-embedding-3-small", provider="openai") model.embed_documents(["Hello, world!", "Goodbye, world!"]) # With additional parameters model = init_embeddings("openai:text-embedding-3-small", api_key="sk-...") ``` !!! version-added "Added in `langchain` 0.3.9" """ if not model: providers = _SUPPORTED_PROVIDERS.keys() msg = f"Must specify model name. Supported providers are: {', '.join(providers)}" raise ValueError(msg) provider, model_name = _infer_model_and_provider(model, provider=provider) pkg = _SUPPORTED_PROVIDERS[provider] _check_pkg(pkg) if provider == "openai": from langchain_openai import OpenAIEmbeddings return OpenAIEmbeddings(model=model_name, **kwargs) if provider == "azure_openai": from langchain_openai import AzureOpenAIEmbeddings return AzureOpenAIEmbeddings(model=model_name, **kwargs) if provider == "google_vertexai": from langchain_google_vertexai import VertexAIEmbeddings return VertexAIEmbeddings(model=model_name, **kwargs) if provider == "bedrock": from langchain_aws import BedrockEmbeddings return BedrockEmbeddings(model_id=model_name, **kwargs) if provider == "cohere": from langchain_cohere import CohereEmbeddings return CohereEmbeddings(model=model_name, **kwargs) if provider == "mistralai": from langchain_mistralai import MistralAIEmbeddings return MistralAIEmbeddings(model=model_name, **kwargs) if provider == "huggingface": from langchain_huggingface import HuggingFaceEmbeddings return HuggingFaceEmbeddings(model_name=model_name, **kwargs) if provider == "ollama": from langchain_ollama import OllamaEmbeddings return OllamaEmbeddings(model=model_name, **kwargs) msg = ( f"Provider '{provider}' is not supported.\n" f"Supported providers and their required packages:\n" f"{_get_provider_list()}" ) raise ValueError(msg) __all__ = [ "Embeddings", # This one is for backwards compatibility "init_embeddings", ]