from __future__ import annotations

import functools
import logging
from collections import defaultdict
from collections.abc import Mapping
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Optional,
    TypeVar,
    Union,
)

from typing_extensions import TypedDict

from langsmith import client as ls_client
from langsmith import run_helpers
from langsmith.schemas import InputTokenDetails, OutputTokenDetails, UsageMetadata

if TYPE_CHECKING:
    from openai import AsyncOpenAI, OpenAI
    from openai.types.chat.chat_completion_chunk import (
        ChatCompletionChunk,
        Choice,
        ChoiceDeltaToolCall,
    )
    from openai.types.completion import Completion
    from openai.types.responses import ResponseStreamEvent  # type: ignore

# Any is used since it may work with Azure or other providers
C = TypeVar("C", bound=Union["OpenAI", "AsyncOpenAI", Any])

logger = logging.getLogger(__name__)


@functools.lru_cache
def _get_omit_types() -> tuple[type, ...]:
    """Get NotGiven/Omit sentinel types used by OpenAI SDK."""
    types: list[type[Any]] = []
    try:
        from openai._types import NotGiven, Omit

        types.append(NotGiven)
        types.append(Omit)
    except ImportError:
        pass
    return tuple(types)


def _strip_not_given(d: dict) -> dict:
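    """Drop `NotGiven`/`Omit` sentinels and null `extra_*` keys from kwargs."""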
    try:
        omit_types = _get_omit_types()
        if not omit_types:
            return d
        return {
            k: v
            for k, v in d.items()
            if not (isinstance(v, omit_types) or (k.startswith("extra_") and v is None))
        }
    except Exception as e:
        logger.error(f"Error stripping NotGiven: {e}")
        return d


def _process_inputs(d: dict) -> dict:
    """Strip `NotGiven` values and serialize `text_format` to JSON schema."""
    d = _strip_not_given(d)
    # Convert text_format (Pydantic model) to JSON schema if present
    if "text_format" in d:
        text_format = d["text_format"]
        if hasattr(text_format, "model_json_schema"):
            try:
                return {
                    **d,
                    "text_format": text_format.model_json_schema(),
                }
            except Exception:
                pass
    return d


def _infer_invocation_params(model_type: str, provider: str, kwargs: dict):
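    """Infer LangSmith invocation metadata (`ls_*` fields) from call kwargs."""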
    stripped = _strip_not_given(kwargs)
    stop = stripped.get("stop")
    if stop and isinstance(stop, str):
        stop = [stop]
    # Allowlist of safe invocation parameters to include
    # Only include known, non-sensitive parameters
    allowed_invocation_keys = {
        "frequency_penalty",
        "n",
        "logit_bias",
        "logprobs",
        "modalities",
        "parallel_tool_calls",
        "prediction",
        "presence_penalty",
        "prompt_cache_key",
        "reasoning",
        "reasoning_effort",
        "response_format",
        "seed",
        "service_tier",
        "stream_options",
        "top_logprobs",
        "top_p",
        "truncation",
        "user",
        "verbosity",
        "web_search_options",
    }
    # Only include allowlisted parameters
    invocation_params = {
        k: v for k, v in stripped.items() if k in allowed_invocation_keys
    }
    return {
        "ls_provider": provider,
        "ls_model_type": model_type,
        "ls_model_name": stripped.get("model"),
        "ls_temperature": stripped.get("temperature"),
        "ls_max_tokens": stripped.get("max_tokens")
        or stripped.get("max_completion_tokens")
        or stripped.get("max_output_tokens"),
        "ls_stop": stop,
        "ls_invocation_params": invocation_params,
    }


def _reduce_choices(choices: list[Choice]) -> dict:
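    """Merge streamed `Choice` deltas for one index into a single message dict."""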
    reversed_choices = list(reversed(choices))
    message: dict[str, Any] = {
        "role": "assistant",
        "content": "",
    }
    for c in reversed_choices:
        if hasattr(c, "delta") and getattr(c.delta, "role", None):
            message["role"] = c.delta.role
            break
    tool_calls: defaultdict[int, list[ChoiceDeltaToolCall]] = defaultdict(list)
    for c in choices:
        if hasattr(c, "delta"):
            if getattr(c.delta, "content", None):
                message["content"] += c.delta.content
            if getattr(c.delta, "function_call", None):
                if not message.get("function_call"):
                    message["function_call"] = {"name": "", "arguments": ""}
                name_ = getattr(c.delta.function_call, "name", None)
                if name_:
                    message["function_call"]["name"] += name_
                arguments_ = getattr(c.delta.function_call, "arguments", None)
                if arguments_:
                    message["function_call"]["arguments"] += arguments_
            if getattr(c.delta, "tool_calls", None):
                tool_calls_list = c.delta.tool_calls
                if tool_calls_list is not None:
                    for tool_call in tool_calls_list:
                        tool_calls[tool_call.index].append(tool_call)
    if tool_calls:
        message["tool_calls"] = [None for _ in range(max(tool_calls.keys()) + 1)]
        for index, tool_call_chunks in tool_calls.items():
            message["tool_calls"][index] = {
                "index": index,
                "id": next((c.id for c in tool_call_chunks if c.id), None),
                "type": next((c.type for c in tool_call_chunks if c.type), None),
                "function": {"name": "", "arguments": ""},
            }
            for chunk in tool_call_chunks:
                if getattr(chunk, "function", None):
                    name_ = getattr(chunk.function, "name", None)
                    if name_:
                        message["tool_calls"][index]["function"]["name"] += name_
                    arguments_ = getattr(chunk.function, "arguments", None)
                    if arguments_:
                        message["tool_calls"][index]["function"]["arguments"] += (
                            arguments_
                        )
    return {
        "index": getattr(choices[0], "index", 0) if choices else 0,
        "finish_reason": next(
            (
                c.finish_reason
                for c in reversed_choices
                if getattr(c, "finish_reason", None)
            ),
            None,
        ),
        "message": message,
    }


def _reduce_chat(all_chunks: list[ChatCompletionChunk]) -> dict:
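    """Fold streamed `ChatCompletionChunk`s into one chat-completion-shaped dict."""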
    choices_by_index: defaultdict[int, list[Choice]] = defaultdict(list)
    for chunk in all_chunks:
        for choice in chunk.choices:
            choices_by_index[choice.index].append(choice)
    if all_chunks:
        d = all_chunks[-1].model_dump()
        d["choices"] = [
            _reduce_choices(choices) for choices in choices_by_index.values()
        ]
    else:
        d = {"choices": [{"message": {"role": "assistant", "content": ""}}]}
    # streamed outputs don't go through `process_outputs`
    # so we need to flatten metadata here
    oai_token_usage = d.pop("usage", None)
    d["usage_metadata"] = (
        _create_usage_metadata(oai_token_usage) if oai_token_usage else None
    )
    return d


def _reduce_completions(all_chunks: list[Completion]) -> dict:
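    """Fold streamed `Completion` chunks into one completion-shaped dict."""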
    all_content = []
    for chunk in all_chunks:
        content = chunk.choices[0].text
        if content is not None:
            all_content.append(content)
    content = "".join(all_content)
    if all_chunks:
        d = all_chunks[-1].model_dump()
        d["choices"] = [{"text": content}]
    else:
        d = {"choices": [{"text": content}]}
    return d


def _create_usage_metadata(
    oai_token_usage: dict, service_tier: Optional[str] = None
) -> UsageMetadata:
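    """Convert an OpenAI usage payload into LangSmith `UsageMetadata`.

    For a recognized `service_tier` ("priority" or "flex"), cache-read, reasoning,
    and remaining tokens are reported under tier-prefixed keys so they can be
    priced separately.
    """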
    recognized_service_tier = (
        service_tier if service_tier in ["priority", "flex"] else None
    )
    service_tier_prefix = (
        f"{recognized_service_tier}_" if recognized_service_tier else ""
    )
    input_tokens = (
        oai_token_usage.get("prompt_tokens") or oai_token_usage.get("input_tokens") or 0
    )
    output_tokens = (
        oai_token_usage.get("completion_tokens")
        or oai_token_usage.get("output_tokens")
        or 0
    )
    total_tokens = oai_token_usage.get("total_tokens") or input_tokens + output_tokens
    input_token_details: dict = {
        "audio": (
            oai_token_usage.get("prompt_tokens_details")
            or oai_token_usage.get("input_tokens_details")
            or {}
        ).get("audio_tokens"),
        f"{service_tier_prefix}cache_read": (
            oai_token_usage.get("prompt_tokens_details")
            or oai_token_usage.get("input_tokens_details")
            or {}
        ).get("cached_tokens"),
    }
    output_token_details: dict = {
        "audio": (
            oai_token_usage.get("completion_tokens_details")
            or oai_token_usage.get("output_tokens_details")
            or {}
        ).get("audio_tokens"),
        f"{service_tier_prefix}reasoning": (
            oai_token_usage.get("completion_tokens_details")
            or oai_token_usage.get("output_tokens_details")
            or {}
        ).get("reasoning_tokens"),
    }
    if recognized_service_tier:
        # Avoid counting cache read and reasoning tokens towards the
        # service tier token count since service tier tokens are already
        # priced differently
        input_token_details[recognized_service_tier] = input_tokens - (
            input_token_details.get(f"{service_tier_prefix}cache_read") or 0
        )
        output_token_details[recognized_service_tier] = output_tokens - (
            output_token_details.get(f"{service_tier_prefix}reasoning") or 0
        )
    return UsageMetadata(
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        total_tokens=total_tokens,
        input_token_details=InputTokenDetails(
            **{k: v for k, v in input_token_details.items() if v is not None}
        ),
        output_token_details=OutputTokenDetails(
            **{k: v for k, v in output_token_details.items() if v is not None}
        ),
    )


def _process_chat_completion(outputs: Any):
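    """Dump a chat completion, lifting its `usage` into `usage_metadata`."""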
    try:
        rdict = outputs.model_dump()
        oai_token_usage = rdict.pop("usage", None)
        rdict["usage_metadata"] = (
            _create_usage_metadata(oai_token_usage, rdict.get("service_tier"))
            if oai_token_usage
            else None
        )
        return rdict
    except BaseException as e:
        logger.debug(f"Error processing chat completion: {e}")
        return {"output": outputs}


def _get_wrapper(
    original_create: Callable,
    name: str,
    reduce_fn: Callable,
    tracing_extra: Optional[TracingExtra] = None,
    invocation_params_fn: Optional[Callable] = None,
    process_outputs: Optional[Callable] = None,
) -> Callable:
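    """Wrap a sync or async `create` method with `run_helpers.traceable`.

    `reduce_fn` is only applied when the call is made with `stream=True`.
    """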
    textra = tracing_extra or {}

    @functools.wraps(original_create)
    def create(*args, **kwargs):
        decorator = run_helpers.traceable(
            name=name,
            run_type="llm",
            reduce_fn=reduce_fn if kwargs.get("stream") is True else None,
            process_inputs=_process_inputs,
            _invocation_params_fn=invocation_params_fn,
            process_outputs=process_outputs,
            **textra,
        )
        return decorator(original_create)(*args, **kwargs)

    @functools.wraps(original_create)
    async def acreate(*args, **kwargs):
        decorator = run_helpers.traceable(
            name=name,
            run_type="llm",
            reduce_fn=reduce_fn if kwargs.get("stream") is True else None,
            process_inputs=_process_inputs,
            _invocation_params_fn=invocation_params_fn,
            process_outputs=process_outputs,
            **textra,
        )
        return await decorator(original_create)(*args, **kwargs)

    return acreate if run_helpers.is_async(original_create) else create


def _get_parse_wrapper(
    original_parse: Callable,
    name: str,
    process_outputs: Callable,
    tracing_extra: Optional[TracingExtra] = None,
    invocation_params_fn: Optional[Callable] = None,
) -> Callable:
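    """Wrap a sync or async `parse` method with `run_helpers.traceable`."""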
    textra = tracing_extra or {}

    @functools.wraps(original_parse)
    def parse(*args, **kwargs):
        decorator = run_helpers.traceable(
            name=name,
            run_type="llm",
            reduce_fn=None,
            process_inputs=_process_inputs,
            _invocation_params_fn=invocation_params_fn,
            process_outputs=process_outputs,
            **textra,
        )
        return decorator(original_parse)(*args, **kwargs)

    @functools.wraps(original_parse)
    async def aparse(*args, **kwargs):
        decorator = run_helpers.traceable(
            name=name,
            run_type="llm",
            reduce_fn=None,
            process_inputs=_process_inputs,
            _invocation_params_fn=invocation_params_fn,
            process_outputs=process_outputs,
            **textra,
        )
        return await decorator(original_parse)(*args, **kwargs)

    return aparse if run_helpers.is_async(original_parse) else parse


def _reduce_response_events(events: list[ResponseStreamEvent]) -> dict:
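    """Reduce a Responses API event stream to its final completed response."""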
    for event in events:
        if event.type == "response.completed":
            return _process_responses_api_output(event.response)
    return {}


class TracingExtra(TypedDict, total=False):
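    """Optional tracing configuration forwarded to the traced runs."""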
    metadata: Optional[Mapping[str, Any]]
    tags: Optional[list[str]]
    client: Optional[ls_client.Client]


def wrap_openai(
    client: C,
    *,
    tracing_extra: Optional[TracingExtra] = None,
    chat_name: str = "ChatOpenAI",
    completions_name: str = "OpenAI",
) -> C:
- """Patch the OpenAI client to make it traceable.
- Supports:
- - Chat and Responses API's
- - Sync and async OpenAI clients
- - `create` and `parse` methods
- - With and without streaming
- Args:
- client: The client to patch.
- tracing_extra: Extra tracing information.
- chat_name: The run name for the chat completions endpoint.
- completions_name: The run name for the completions endpoint.
- Returns:
- The patched client.
- Example:
- ```python
- import openai
- from langsmith import wrappers
- # Use OpenAI client same as you normally would.
- client = wrappers.wrap_openai(openai.OpenAI())
- # Chat API:
- messages = [
- {"role": "system", "content": "You are a helpful assistant."},
- {
- "role": "user",
- "content": "What physics breakthroughs do you predict will happen by 2300?",
- },
- ]
- completion = client.chat.completions.create(
- model="gpt-4o-mini", messages=messages
- )
- print(completion.choices[0].message.content)
- # Responses API:
- response = client.responses.create(
- model="gpt-4o-mini",
- messages=messages,
- )
- print(response.output_text)
- ```
- !!! warning "Behavior changed in `langsmith` 0.3.16"
- Support for Responses API added.
- """ # noqa: E501
    tracing_extra = tracing_extra or {}
    ls_provider = "openai"
    try:
        from openai import AsyncAzureOpenAI, AzureOpenAI

        if isinstance(client, AzureOpenAI) or isinstance(client, AsyncAzureOpenAI):
            ls_provider = "azure"
            chat_name = "AzureChatOpenAI"
            completions_name = "AzureOpenAI"
    except ImportError:
        pass

    # First wrap the create methods - these handle non-streaming cases
    client.chat.completions.create = _get_wrapper(  # type: ignore[method-assign]
        client.chat.completions.create,
        chat_name,
        _reduce_chat,
        tracing_extra=tracing_extra,
        invocation_params_fn=functools.partial(
            _infer_invocation_params, "chat", ls_provider
        ),
        process_outputs=_process_chat_completion,
    )
    client.completions.create = _get_wrapper(  # type: ignore[method-assign]
        client.completions.create,
        completions_name,
        _reduce_completions,
        tracing_extra=tracing_extra,
        invocation_params_fn=functools.partial(
            _infer_invocation_params, "llm", ls_provider
        ),
    )

    # Wrap beta.chat.completions.parse if it exists
    if (
        hasattr(client, "beta")
        and hasattr(client.beta, "chat")
        and hasattr(client.beta.chat, "completions")
        and hasattr(client.beta.chat.completions, "parse")
    ):
        client.beta.chat.completions.parse = _get_parse_wrapper(  # type: ignore[method-assign]
            client.beta.chat.completions.parse,  # type: ignore
            chat_name,
            _process_chat_completion,
            tracing_extra=tracing_extra,
            invocation_params_fn=functools.partial(
                _infer_invocation_params, "chat", ls_provider
            ),
        )

    # Wrap chat.completions.parse if it exists
    if (
        hasattr(client, "chat")
        and hasattr(client.chat, "completions")
        and hasattr(client.chat.completions, "parse")
    ):
        client.chat.completions.parse = _get_parse_wrapper(  # type: ignore[method-assign]
            client.chat.completions.parse,  # type: ignore
            chat_name,
            _process_chat_completion,
            tracing_extra=tracing_extra,
            invocation_params_fn=functools.partial(
                _infer_invocation_params, "chat", ls_provider
            ),
        )

    # For the responses API: "client.responses.create(**kwargs)"
    if hasattr(client, "responses"):
        if hasattr(client.responses, "create"):
            client.responses.create = _get_wrapper(  # type: ignore[method-assign]
                client.responses.create,
                chat_name,
                _reduce_response_events,
                process_outputs=_process_responses_api_output,
                tracing_extra=tracing_extra,
                invocation_params_fn=functools.partial(
                    _infer_invocation_params, "chat", ls_provider
                ),
            )
        if hasattr(client.responses, "parse"):
            client.responses.parse = _get_parse_wrapper(  # type: ignore[method-assign]
                client.responses.parse,
                chat_name,
                _process_responses_api_output,
                tracing_extra=tracing_extra,
                invocation_params_fn=functools.partial(
                    _infer_invocation_params, "chat", ls_provider
                ),
            )

    return client


def _process_responses_api_output(response: Any) -> dict:
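    """Dump a Responses API response, lifting `usage` into `usage_metadata`."""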
    if response:
        try:
            output = response.model_dump(exclude_none=True, mode="json")
            if usage := output.pop("usage", None):
                output["usage_metadata"] = _create_usage_metadata(
                    usage, output.get("service_tier")
                )
            return output
        except Exception:
            return {"output": response}
    return {}