  1. """Factory functions for embeddings."""
  2. import functools
  3. from importlib import util
  4. from typing import Any
  5. from langchain_core.embeddings import Embeddings
  6. _SUPPORTED_PROVIDERS = {
  7. "azure_openai": "langchain_openai",
  8. "bedrock": "langchain_aws",
  9. "cohere": "langchain_cohere",
  10. "google_vertexai": "langchain_google_vertexai",
  11. "huggingface": "langchain_huggingface",
  12. "mistralai": "langchain_mistralai",
  13. "ollama": "langchain_ollama",
  14. "openai": "langchain_openai",
  15. }
  16. def _get_provider_list() -> str:
  17. """Get formatted list of providers and their packages."""
  18. return "\n".join(f" - {p}: {pkg.replace('_', '-')}" for p, pkg in _SUPPORTED_PROVIDERS.items())
  19. def _parse_model_string(model_name: str) -> tuple[str, str]:
  20. """Parse a model string into provider and model name components.
  21. The model string should be in the format 'provider:model-name', where provider
  22. is one of the supported providers.
  23. Args:
  24. model_name: A model string in the format 'provider:model-name'
  25. Returns:
  26. A tuple of (provider, model_name)
  27. ```python
  28. _parse_model_string("openai:text-embedding-3-small")
  29. # Returns: ("openai", "text-embedding-3-small")
  30. _parse_model_string("bedrock:amazon.titan-embed-text-v1")
  31. # Returns: ("bedrock", "amazon.titan-embed-text-v1")
  32. ```
  33. Raises:
  34. ValueError: If the model string is not in the correct format or
  35. the provider is unsupported
  36. """
  37. if ":" not in model_name:
  38. providers = _SUPPORTED_PROVIDERS
  39. msg = (
  40. f"Invalid model format '{model_name}'.\n"
  41. f"Model name must be in format 'provider:model-name'\n"
  42. f"Example valid model strings:\n"
  43. f" - openai:text-embedding-3-small\n"
  44. f" - bedrock:amazon.titan-embed-text-v1\n"
  45. f" - cohere:embed-english-v3.0\n"
  46. f"Supported providers: {providers}"
  47. )
  48. raise ValueError(msg)
  49. provider, model = model_name.split(":", 1)
  50. provider = provider.lower().strip()
  51. model = model.strip()
  52. if provider not in _SUPPORTED_PROVIDERS:
  53. msg = (
  54. f"Provider '{provider}' is not supported.\n"
  55. f"Supported providers and their required packages:\n"
  56. f"{_get_provider_list()}"
  57. )
  58. raise ValueError(msg)
  59. if not model:
  60. msg = "Model name cannot be empty"
  61. raise ValueError(msg)
  62. return provider, model
  63. def _infer_model_and_provider(
  64. model: str,
  65. *,
  66. provider: str | None = None,
  67. ) -> tuple[str, str]:
  68. if not model.strip():
  69. msg = "Model name cannot be empty"
  70. raise ValueError(msg)
  71. if provider is None and ":" in model:
  72. provider, model_name = _parse_model_string(model)
  73. else:
  74. model_name = model
  75. if not provider:
  76. providers = _SUPPORTED_PROVIDERS
  77. msg = (
  78. "Must specify either:\n"
  79. "1. A model string in format 'provider:model-name'\n"
  80. " Example: 'openai:text-embedding-3-small'\n"
  81. "2. Or explicitly set provider from: "
  82. f"{providers}"
  83. )
  84. raise ValueError(msg)
  85. if provider not in _SUPPORTED_PROVIDERS:
  86. msg = (
  87. f"Provider '{provider}' is not supported.\n"
  88. f"Supported providers and their required packages:\n"
  89. f"{_get_provider_list()}"
  90. )
  91. raise ValueError(msg)
  92. return provider, model_name
  93. @functools.lru_cache(maxsize=len(_SUPPORTED_PROVIDERS))
  94. def _check_pkg(pkg: str) -> None:
  95. """Check if a package is installed."""
  96. if not util.find_spec(pkg):
  97. msg = f"Could not import {pkg} python package. Please install it with `pip install {pkg}`"
  98. raise ImportError(msg)
def init_embeddings(
    model: str,
    *,
    provider: str | None = None,
    **kwargs: Any,
) -> Embeddings:
    """Initialize an embedding model from a model name and optional provider.

    !!! note
        Requires the integration package for the chosen model provider to be installed.
        See the `model_provider` parameter below for specific package names
        (e.g., `pip install langchain-openai`).

    Refer to the [provider integration's API reference](https://docs.langchain.com/oss/python/integrations/providers)
    for supported model parameters to use as `**kwargs`.

    Args:
        model: The name of the model, e.g. `'openai:text-embedding-3-small'`.
            You can also specify model and model provider in a single argument using
            `'{model_provider}:{model}'` format, e.g. `'openai:text-embedding-3-small'`.
        provider: The model provider if not specified as part of the model arg
            (see above).

            Supported `provider` values and the corresponding integration package
            are:

            - `openai` -> [`langchain-openai`](https://docs.langchain.com/oss/python/integrations/providers/openai)
            - `azure_openai` -> [`langchain-openai`](https://docs.langchain.com/oss/python/integrations/providers/openai)
            - `bedrock` -> [`langchain-aws`](https://docs.langchain.com/oss/python/integrations/providers/aws)
            - `cohere` -> [`langchain-cohere`](https://docs.langchain.com/oss/python/integrations/providers/cohere)
            - `google_vertexai` -> [`langchain-google-vertexai`](https://docs.langchain.com/oss/python/integrations/providers/google)
            - `huggingface` -> [`langchain-huggingface`](https://docs.langchain.com/oss/python/integrations/providers/huggingface)
            - `mistralai` -> [`langchain-mistralai`](https://docs.langchain.com/oss/python/integrations/providers/mistralai)
            - `ollama` -> [`langchain-ollama`](https://docs.langchain.com/oss/python/integrations/providers/ollama)
        **kwargs: Additional model-specific parameters passed to the embedding model.
            These vary by provider. Refer to the specific model provider's
            [integration reference](https://reference.langchain.com/python/integrations/)
            for all available parameters.

    Returns:
        An `Embeddings` instance that can generate embeddings for text.

    Raises:
        ValueError: If the model provider is not supported or cannot be determined
        ImportError: If the required provider package is not installed

    ???+ example
        ```python
        # pip install langchain langchain-openai

        # Using a model string
        model = init_embeddings("openai:text-embedding-3-small")
        model.embed_query("Hello, world!")

        # Using explicit provider
        model = init_embeddings(model="text-embedding-3-small", provider="openai")
        model.embed_documents(["Hello, world!", "Goodbye, world!"])

        # With additional parameters
        model = init_embeddings("openai:text-embedding-3-small", api_key="sk-...")
        ```

    !!! version-added "Added in `langchain` 0.3.9"
    """
    if not model:
        providers = _SUPPORTED_PROVIDERS.keys()
        msg = f"Must specify model name. Supported providers are: {', '.join(providers)}"
        raise ValueError(msg)

    provider, model_name = _infer_model_and_provider(model, provider=provider)
    # Fail fast with an actionable message if the integration is missing.
    pkg = _SUPPORTED_PROVIDERS[provider]
    _check_pkg(pkg)

    # Each branch imports lazily so only the selected provider's integration
    # package needs to be installed.
    if provider == "openai":
        from langchain_openai import OpenAIEmbeddings

        return OpenAIEmbeddings(model=model_name, **kwargs)
    if provider == "azure_openai":
        from langchain_openai import AzureOpenAIEmbeddings

        return AzureOpenAIEmbeddings(model=model_name, **kwargs)
    if provider == "google_vertexai":
        from langchain_google_vertexai import VertexAIEmbeddings

        return VertexAIEmbeddings(model=model_name, **kwargs)
    if provider == "bedrock":
        from langchain_aws import BedrockEmbeddings

        # Bedrock takes `model_id`, not `model`.
        return BedrockEmbeddings(model_id=model_name, **kwargs)
    if provider == "cohere":
        from langchain_cohere import CohereEmbeddings

        return CohereEmbeddings(model=model_name, **kwargs)
    if provider == "mistralai":
        from langchain_mistralai import MistralAIEmbeddings

        return MistralAIEmbeddings(model=model_name, **kwargs)
    if provider == "huggingface":
        from langchain_huggingface import HuggingFaceEmbeddings

        # HuggingFace takes `model_name`, not `model`.
        return HuggingFaceEmbeddings(model_name=model_name, **kwargs)
    if provider == "ollama":
        from langchain_ollama import OllamaEmbeddings

        return OllamaEmbeddings(model=model_name, **kwargs)

    # Defensive fallback: unreachable today because _infer_model_and_provider
    # validates the provider, but kept so a newly added provider key without a
    # matching branch fails loudly instead of returning None.
    msg = (
        f"Provider '{provider}' is not supported.\n"
        f"Supported providers and their required packages:\n"
        f"{_get_provider_list()}"
    )
    raise ValueError(msg)
__all__ = [
    "Embeddings",  # This one is for backwards compatibility
    "init_embeddings",
]