from __future__ import annotations import functools import logging from collections import defaultdict from collections.abc import Mapping from typing import ( TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union, ) from typing_extensions import TypedDict from langsmith import client as ls_client from langsmith import run_helpers from langsmith.schemas import InputTokenDetails, OutputTokenDetails, UsageMetadata if TYPE_CHECKING: from openai import AsyncOpenAI, OpenAI from openai.types.chat.chat_completion_chunk import ( ChatCompletionChunk, Choice, ChoiceDeltaToolCall, ) from openai.types.completion import Completion from openai.types.responses import ResponseStreamEvent # type: ignore # Any is used since it may work with Azure or other providers C = TypeVar("C", bound=Union["OpenAI", "AsyncOpenAI", Any]) logger = logging.getLogger(__name__) @functools.lru_cache def _get_omit_types() -> tuple[type, ...]: """Get NotGiven/Omit sentinel types used by OpenAI SDK.""" types: list[type[Any]] = [] try: from openai._types import NotGiven, Omit types.append(NotGiven) types.append(Omit) except ImportError: pass return tuple(types) def _strip_not_given(d: dict) -> dict: try: omit_types = _get_omit_types() if not omit_types: return d return { k: v for k, v in d.items() if not (isinstance(v, omit_types) or (k.startswith("extra_") and v is None)) } except Exception as e: logger.error(f"Error stripping NotGiven: {e}") return d def _process_inputs(d: dict) -> dict: """Strip `NotGiven` values and serialize `text_format` to JSON schema.""" d = _strip_not_given(d) # Convert text_format (Pydantic model) to JSON schema if present if "text_format" in d: text_format = d["text_format"] if hasattr(text_format, "model_json_schema"): try: return { **d, "text_format": text_format.model_json_schema(), } except Exception: pass return d def _infer_invocation_params(model_type: str, provider: str, kwargs: dict): stripped = _strip_not_given(kwargs) stop = stripped.get("stop") if stop and isinstance(stop, str): stop = [stop] # Allowlist of safe invocation parameters to include # Only include known, non-sensitive parameters allowed_invocation_keys = { "frequency_penalty", "n", "logit_bias", "logprobs", "modalities", "parallel_tool_calls", "prediction", "presence_penalty", "prompt_cache_key", "reasoning", "reasoning_effort", "response_format", "seed", "service_tier", "stream_options", "top_logprobs", "top_p", "truncation", "user", "verbosity", "web_search_options", } # Only include allowlisted parameters invocation_params = { k: v for k, v in stripped.items() if k in allowed_invocation_keys } return { "ls_provider": provider, "ls_model_type": model_type, "ls_model_name": stripped.get("model"), "ls_temperature": stripped.get("temperature"), "ls_max_tokens": stripped.get("max_tokens") or stripped.get("max_completion_tokens") or stripped.get("max_output_tokens"), "ls_stop": stop, "ls_invocation_params": invocation_params, } def _reduce_choices(choices: list[Choice]) -> dict: reversed_choices = list(reversed(choices)) message: dict[str, Any] = { "role": "assistant", "content": "", } for c in reversed_choices: if hasattr(c, "delta") and getattr(c.delta, "role", None): message["role"] = c.delta.role break tool_calls: defaultdict[int, list[ChoiceDeltaToolCall]] = defaultdict(list) for c in choices: if hasattr(c, "delta"): if getattr(c.delta, "content", None): message["content"] += c.delta.content if getattr(c.delta, "function_call", None): if not message.get("function_call"): message["function_call"] = {"name": "", "arguments": ""} name_ = getattr(c.delta.function_call, "name", None) if name_: message["function_call"]["name"] += name_ arguments_ = getattr(c.delta.function_call, "arguments", None) if arguments_: message["function_call"]["arguments"] += arguments_ if getattr(c.delta, "tool_calls", None): tool_calls_list = c.delta.tool_calls if tool_calls_list is not None: for tool_call in tool_calls_list: tool_calls[tool_call.index].append(tool_call) if tool_calls: message["tool_calls"] = [None for _ in range(max(tool_calls.keys()) + 1)] for index, tool_call_chunks in tool_calls.items(): message["tool_calls"][index] = { "index": index, "id": next((c.id for c in tool_call_chunks if c.id), None), "type": next((c.type for c in tool_call_chunks if c.type), None), "function": {"name": "", "arguments": ""}, } for chunk in tool_call_chunks: if getattr(chunk, "function", None): name_ = getattr(chunk.function, "name", None) if name_: message["tool_calls"][index]["function"]["name"] += name_ arguments_ = getattr(chunk.function, "arguments", None) if arguments_: message["tool_calls"][index]["function"]["arguments"] += ( arguments_ ) return { "index": getattr(choices[0], "index", 0) if choices else 0, "finish_reason": next( ( c.finish_reason for c in reversed_choices if getattr(c, "finish_reason", None) ), None, ), "message": message, } def _reduce_chat(all_chunks: list[ChatCompletionChunk]) -> dict: choices_by_index: defaultdict[int, list[Choice]] = defaultdict(list) for chunk in all_chunks: for choice in chunk.choices: choices_by_index[choice.index].append(choice) if all_chunks: d = all_chunks[-1].model_dump() d["choices"] = [ _reduce_choices(choices) for choices in choices_by_index.values() ] else: d = {"choices": [{"message": {"role": "assistant", "content": ""}}]} # streamed outputs don't go through `process_outputs` # so we need to flatten metadata here oai_token_usage = d.pop("usage", None) d["usage_metadata"] = ( _create_usage_metadata(oai_token_usage) if oai_token_usage else None ) return d def _reduce_completions(all_chunks: list[Completion]) -> dict: all_content = [] for chunk in all_chunks: content = chunk.choices[0].text if content is not None: all_content.append(content) content = "".join(all_content) if all_chunks: d = all_chunks[-1].model_dump() d["choices"] = [{"text": content}] else: d = {"choices": [{"text": content}]} return d def _create_usage_metadata( oai_token_usage: dict, service_tier: Optional[str] = None ) -> UsageMetadata: recognized_service_tier = ( service_tier if service_tier in ["priority", "flex"] else None ) service_tier_prefix = ( f"{recognized_service_tier}_" if recognized_service_tier else "" ) input_tokens = ( oai_token_usage.get("prompt_tokens") or oai_token_usage.get("input_tokens") or 0 ) output_tokens = ( oai_token_usage.get("completion_tokens") or oai_token_usage.get("output_tokens") or 0 ) total_tokens = oai_token_usage.get("total_tokens") or input_tokens + output_tokens input_token_details: dict = { "audio": ( oai_token_usage.get("prompt_tokens_details") or oai_token_usage.get("input_tokens_details") or {} ).get("audio_tokens"), f"{service_tier_prefix}cache_read": ( oai_token_usage.get("prompt_tokens_details") or oai_token_usage.get("input_tokens_details") or {} ).get("cached_tokens"), } output_token_details: dict = { "audio": ( oai_token_usage.get("completion_tokens_details") or oai_token_usage.get("output_tokens_details") or {} ).get("audio_tokens"), f"{service_tier_prefix}reasoning": ( oai_token_usage.get("completion_tokens_details") or oai_token_usage.get("output_tokens_details") or {} ).get("reasoning_tokens"), } if recognized_service_tier: # Avoid counting cache read and reasoning tokens towards the # service tier token count since service tier tokens are already # priced differently input_token_details[recognized_service_tier] = input_tokens - ( input_token_details.get(f"{service_tier_prefix}cache_read") or 0 ) output_token_details[recognized_service_tier] = output_tokens - ( output_token_details.get(f"{service_tier_prefix}reasoning") or 0 ) return UsageMetadata( input_tokens=input_tokens, output_tokens=output_tokens, total_tokens=total_tokens, input_token_details=InputTokenDetails( **{k: v for k, v in input_token_details.items() if v is not None} ), output_token_details=OutputTokenDetails( **{k: v for k, v in output_token_details.items() if v is not None} ), ) def _process_chat_completion(outputs: Any): try: rdict = outputs.model_dump() oai_token_usage = rdict.pop("usage", None) rdict["usage_metadata"] = ( _create_usage_metadata(oai_token_usage, rdict.get("service_tier")) if oai_token_usage else None ) return rdict except BaseException as e: logger.debug(f"Error processing chat completion: {e}") return {"output": outputs} def _get_wrapper( original_create: Callable, name: str, reduce_fn: Callable, tracing_extra: Optional[TracingExtra] = None, invocation_params_fn: Optional[Callable] = None, process_outputs: Optional[Callable] = None, ) -> Callable: textra = tracing_extra or {} @functools.wraps(original_create) def create(*args, **kwargs): decorator = run_helpers.traceable( name=name, run_type="llm", reduce_fn=reduce_fn if kwargs.get("stream") is True else None, process_inputs=_process_inputs, _invocation_params_fn=invocation_params_fn, process_outputs=process_outputs, **textra, ) return decorator(original_create)(*args, **kwargs) @functools.wraps(original_create) async def acreate(*args, **kwargs): decorator = run_helpers.traceable( name=name, run_type="llm", reduce_fn=reduce_fn if kwargs.get("stream") is True else None, process_inputs=_process_inputs, _invocation_params_fn=invocation_params_fn, process_outputs=process_outputs, **textra, ) return await decorator(original_create)(*args, **kwargs) return acreate if run_helpers.is_async(original_create) else create def _get_parse_wrapper( original_parse: Callable, name: str, process_outputs: Callable, tracing_extra: Optional[TracingExtra] = None, invocation_params_fn: Optional[Callable] = None, ) -> Callable: textra = tracing_extra or {} @functools.wraps(original_parse) def parse(*args, **kwargs): decorator = run_helpers.traceable( name=name, run_type="llm", reduce_fn=None, process_inputs=_process_inputs, _invocation_params_fn=invocation_params_fn, process_outputs=process_outputs, **textra, ) return decorator(original_parse)(*args, **kwargs) @functools.wraps(original_parse) async def aparse(*args, **kwargs): decorator = run_helpers.traceable( name=name, run_type="llm", reduce_fn=None, process_inputs=_process_inputs, _invocation_params_fn=invocation_params_fn, process_outputs=process_outputs, **textra, ) return await decorator(original_parse)(*args, **kwargs) return aparse if run_helpers.is_async(original_parse) else parse def _reduce_response_events(events: list[ResponseStreamEvent]) -> dict: for event in events: if event.type == "response.completed": return _process_responses_api_output(event.response) return {} class TracingExtra(TypedDict, total=False): metadata: Optional[Mapping[str, Any]] tags: Optional[list[str]] client: Optional[ls_client.Client] def wrap_openai( client: C, *, tracing_extra: Optional[TracingExtra] = None, chat_name: str = "ChatOpenAI", completions_name: str = "OpenAI", ) -> C: """Patch the OpenAI client to make it traceable. Supports: - Chat and Responses API's - Sync and async OpenAI clients - `create` and `parse` methods - With and without streaming Args: client: The client to patch. tracing_extra: Extra tracing information. chat_name: The run name for the chat completions endpoint. completions_name: The run name for the completions endpoint. Returns: The patched client. Example: ```python import openai from langsmith import wrappers # Use OpenAI client same as you normally would. client = wrappers.wrap_openai(openai.OpenAI()) # Chat API: messages = [ {"role": "system", "content": "You are a helpful assistant."}, { "role": "user", "content": "What physics breakthroughs do you predict will happen by 2300?", }, ] completion = client.chat.completions.create( model="gpt-4o-mini", messages=messages ) print(completion.choices[0].message.content) # Responses API: response = client.responses.create( model="gpt-4o-mini", messages=messages, ) print(response.output_text) ``` !!! warning "Behavior changed in `langsmith` 0.3.16" Support for Responses API added. """ # noqa: E501 tracing_extra = tracing_extra or {} ls_provider = "openai" try: from openai import AsyncAzureOpenAI, AzureOpenAI if isinstance(client, AzureOpenAI) or isinstance(client, AsyncAzureOpenAI): ls_provider = "azure" chat_name = "AzureChatOpenAI" completions_name = "AzureOpenAI" except ImportError: pass # First wrap the create methods - these handle non-streaming cases client.chat.completions.create = _get_wrapper( # type: ignore[method-assign] client.chat.completions.create, chat_name, _reduce_chat, tracing_extra=tracing_extra, invocation_params_fn=functools.partial( _infer_invocation_params, "chat", ls_provider ), process_outputs=_process_chat_completion, ) client.completions.create = _get_wrapper( # type: ignore[method-assign] client.completions.create, completions_name, _reduce_completions, tracing_extra=tracing_extra, invocation_params_fn=functools.partial( _infer_invocation_params, "llm", ls_provider ), ) # Wrap beta.chat.completions.parse if it exists if ( hasattr(client, "beta") and hasattr(client.beta, "chat") and hasattr(client.beta.chat, "completions") and hasattr(client.beta.chat.completions, "parse") ): client.beta.chat.completions.parse = _get_parse_wrapper( # type: ignore[method-assign] client.beta.chat.completions.parse, # type: ignore chat_name, _process_chat_completion, tracing_extra=tracing_extra, invocation_params_fn=functools.partial( _infer_invocation_params, "chat", ls_provider ), ) # Wrap chat.completions.parse if it exists if ( hasattr(client, "chat") and hasattr(client.chat, "completions") and hasattr(client.chat.completions, "parse") ): client.chat.completions.parse = _get_parse_wrapper( # type: ignore[method-assign] client.chat.completions.parse, # type: ignore chat_name, _process_chat_completion, tracing_extra=tracing_extra, invocation_params_fn=functools.partial( _infer_invocation_params, "chat", ls_provider ), ) # For the responses API: "client.responses.create(**kwargs)" if hasattr(client, "responses"): if hasattr(client.responses, "create"): client.responses.create = _get_wrapper( # type: ignore[method-assign] client.responses.create, chat_name, _reduce_response_events, process_outputs=_process_responses_api_output, tracing_extra=tracing_extra, invocation_params_fn=functools.partial( _infer_invocation_params, "chat", ls_provider ), ) if hasattr(client.responses, "parse"): client.responses.parse = _get_parse_wrapper( # type: ignore[method-assign] client.responses.parse, chat_name, _process_responses_api_output, tracing_extra=tracing_extra, invocation_params_fn=functools.partial( _infer_invocation_params, "chat", ls_provider ), ) return client def _process_responses_api_output(response: Any) -> dict: if response: try: output = response.model_dump(exclude_none=True, mode="json") if usage := output.pop("usage", None): output["usage_metadata"] = _create_usage_metadata( usage, output.get("service_tier") ) return output except Exception: return {"output": response} return {}