_openai.py

from __future__ import annotations

import functools
import logging
from collections import defaultdict
from collections.abc import Mapping
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Optional,
    TypeVar,
    Union,
)

from typing_extensions import TypedDict

from langsmith import client as ls_client
from langsmith import run_helpers
from langsmith.schemas import InputTokenDetails, OutputTokenDetails, UsageMetadata

if TYPE_CHECKING:
    from openai import AsyncOpenAI, OpenAI
    from openai.types.chat.chat_completion_chunk import (
        ChatCompletionChunk,
        Choice,
        ChoiceDeltaToolCall,
    )
    from openai.types.completion import Completion
    from openai.types.responses import ResponseStreamEvent  # type: ignore

# Any is used since it may work with Azure or other providers
C = TypeVar("C", bound=Union["OpenAI", "AsyncOpenAI", Any])

logger = logging.getLogger(__name__)


@functools.lru_cache
def _get_omit_types() -> tuple[type, ...]:
    """Get NotGiven/Omit sentinel types used by OpenAI SDK."""
    types: list[type[Any]] = []
    try:
        from openai._types import NotGiven, Omit

        types.append(NotGiven)
        types.append(Omit)
    except ImportError:
        pass
    return tuple(types)


def _strip_not_given(d: dict) -> dict:
    try:
        omit_types = _get_omit_types()
        if not omit_types:
            return d
        return {
            k: v
            for k, v in d.items()
            if not (isinstance(v, omit_types) or (k.startswith("extra_") and v is None))
        }
    except Exception as e:
        logger.error(f"Error stripping NotGiven: {e}")
        return d


def _process_inputs(d: dict) -> dict:
    """Strip `NotGiven` values and serialize `text_format` to JSON schema."""
    d = _strip_not_given(d)
    # Convert text_format (Pydantic model) to JSON schema if present
    if "text_format" in d:
        text_format = d["text_format"]
        if hasattr(text_format, "model_json_schema"):
            try:
                return {
                    **d,
                    "text_format": text_format.model_json_schema(),
                }
            except Exception:
                pass
    return d
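
# Illustrative sketch, not part of the library: how `_process_inputs` treats a
# structured-output call (the `Answer` model below is a hypothetical example).
#
#     from openai import NOT_GIVEN
#     from pydantic import BaseModel
#
#     class Answer(BaseModel):
#         text: str
#
#     _process_inputs({"model": "gpt-4o-mini", "text_format": Answer, "timeout": NOT_GIVEN})
#     # -> {"model": "gpt-4o-mini", "text_format": {...Answer's JSON schema...}}
#
# The `NotGiven` sentinel is dropped and the Pydantic model is replaced by its
# `model_json_schema()` output so the traced inputs stay JSON-serializable.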


def _infer_invocation_params(model_type: str, provider: str, kwargs: dict):
    stripped = _strip_not_given(kwargs)
    stop = stripped.get("stop")
    if stop and isinstance(stop, str):
        stop = [stop]
    # Allowlist of safe invocation parameters to include.
    # Only include known, non-sensitive parameters.
    allowed_invocation_keys = {
        "frequency_penalty",
        "n",
        "logit_bias",
        "logprobs",
        "modalities",
        "parallel_tool_calls",
        "prediction",
        "presence_penalty",
        "prompt_cache_key",
        "reasoning",
        "reasoning_effort",
        "response_format",
        "seed",
        "service_tier",
        "stream_options",
        "top_logprobs",
        "top_p",
        "truncation",
        "user",
        "verbosity",
        "web_search_options",
    }
    # Only include allowlisted parameters
    invocation_params = {
        k: v for k, v in stripped.items() if k in allowed_invocation_keys
    }
    return {
        "ls_provider": provider,
        "ls_model_type": model_type,
        "ls_model_name": stripped.get("model"),
        "ls_temperature": stripped.get("temperature"),
        "ls_max_tokens": stripped.get("max_tokens")
        or stripped.get("max_completion_tokens")
        or stripped.get("max_output_tokens"),
        "ls_stop": stop,
        "ls_invocation_params": invocation_params,
    }
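
# Illustrative sketch, not part of the library: for a call such as
# `client.chat.completions.create(model="gpt-4o-mini", temperature=0.2, stop="\n",
# top_p=0.9, messages=[...])`, this helper returns roughly:
#
#     {
#         "ls_provider": "openai",
#         "ls_model_type": "chat",
#         "ls_model_name": "gpt-4o-mini",
#         "ls_temperature": 0.2,
#         "ls_max_tokens": None,
#         "ls_stop": ["\n"],
#         "ls_invocation_params": {"top_p": 0.9},
#     }
#
# `messages` (and any other non-allowlisted key) is deliberately left out of
# `ls_invocation_params`.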


def _reduce_choices(choices: list[Choice]) -> dict:
    reversed_choices = list(reversed(choices))
    message: dict[str, Any] = {
        "role": "assistant",
        "content": "",
    }
    for c in reversed_choices:
        if hasattr(c, "delta") and getattr(c.delta, "role", None):
            message["role"] = c.delta.role
            break
    tool_calls: defaultdict[int, list[ChoiceDeltaToolCall]] = defaultdict(list)
    for c in choices:
        if hasattr(c, "delta"):
            if getattr(c.delta, "content", None):
                message["content"] += c.delta.content
            if getattr(c.delta, "function_call", None):
                if not message.get("function_call"):
                    message["function_call"] = {"name": "", "arguments": ""}
                name_ = getattr(c.delta.function_call, "name", None)
                if name_:
                    message["function_call"]["name"] += name_
                arguments_ = getattr(c.delta.function_call, "arguments", None)
                if arguments_:
                    message["function_call"]["arguments"] += arguments_
            if getattr(c.delta, "tool_calls", None):
                tool_calls_list = c.delta.tool_calls
                if tool_calls_list is not None:
                    for tool_call in tool_calls_list:
                        tool_calls[tool_call.index].append(tool_call)
    if tool_calls:
        message["tool_calls"] = [None for _ in range(max(tool_calls.keys()) + 1)]
        for index, tool_call_chunks in tool_calls.items():
            message["tool_calls"][index] = {
                "index": index,
                "id": next((c.id for c in tool_call_chunks if c.id), None),
                "type": next((c.type for c in tool_call_chunks if c.type), None),
                "function": {"name": "", "arguments": ""},
            }
            for chunk in tool_call_chunks:
                if getattr(chunk, "function", None):
                    name_ = getattr(chunk.function, "name", None)
                    if name_:
                        message["tool_calls"][index]["function"]["name"] += name_
                    arguments_ = getattr(chunk.function, "arguments", None)
                    if arguments_:
                        message["tool_calls"][index]["function"]["arguments"] += (
                            arguments_
                        )
    return {
        "index": getattr(choices[0], "index", 0) if choices else 0,
        "finish_reason": next(
            (
                c.finish_reason
                for c in reversed_choices
                if getattr(c, "finish_reason", None)
            ),
            None,
        ),
        "message": message,
    }


def _reduce_chat(all_chunks: list[ChatCompletionChunk]) -> dict:
    choices_by_index: defaultdict[int, list[Choice]] = defaultdict(list)
    for chunk in all_chunks:
        for choice in chunk.choices:
            choices_by_index[choice.index].append(choice)
    if all_chunks:
        d = all_chunks[-1].model_dump()
        d["choices"] = [
            _reduce_choices(choices) for choices in choices_by_index.values()
        ]
    else:
        d = {"choices": [{"message": {"role": "assistant", "content": ""}}]}
    # Streamed outputs don't go through `process_outputs`,
    # so we need to flatten metadata here.
    oai_token_usage = d.pop("usage", None)
    d["usage_metadata"] = (
        _create_usage_metadata(oai_token_usage) if oai_token_usage else None
    )
    return d
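
# Illustrative sketch, not part of the library: two streamed deltas for the same
# choice index, e.g. `delta={"role": "assistant", "content": "Hel"}` followed by
# `delta={"content": "lo"}`, reduce to a single choice whose message is
# `{"role": "assistant", "content": "Hello"}`; `usage_metadata` is flattened from
# the final chunk's `usage` field when the API includes one.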


def _reduce_completions(all_chunks: list[Completion]) -> dict:
    all_content = []
    for chunk in all_chunks:
        content = chunk.choices[0].text
        if content is not None:
            all_content.append(content)
    content = "".join(all_content)
    if all_chunks:
        d = all_chunks[-1].model_dump()
        d["choices"] = [{"text": content}]
    else:
        d = {"choices": [{"text": content}]}
    return d


def _create_usage_metadata(
    oai_token_usage: dict, service_tier: Optional[str] = None
) -> UsageMetadata:
    recognized_service_tier = (
        service_tier if service_tier in ["priority", "flex"] else None
    )
    service_tier_prefix = (
        f"{recognized_service_tier}_" if recognized_service_tier else ""
    )
    input_tokens = (
        oai_token_usage.get("prompt_tokens") or oai_token_usage.get("input_tokens") or 0
    )
    output_tokens = (
        oai_token_usage.get("completion_tokens")
        or oai_token_usage.get("output_tokens")
        or 0
    )
    total_tokens = oai_token_usage.get("total_tokens") or input_tokens + output_tokens
    input_token_details: dict = {
        "audio": (
            oai_token_usage.get("prompt_tokens_details")
            or oai_token_usage.get("input_tokens_details")
            or {}
        ).get("audio_tokens"),
        f"{service_tier_prefix}cache_read": (
            oai_token_usage.get("prompt_tokens_details")
            or oai_token_usage.get("input_tokens_details")
            or {}
        ).get("cached_tokens"),
    }
    output_token_details: dict = {
        "audio": (
            oai_token_usage.get("completion_tokens_details")
            or oai_token_usage.get("output_tokens_details")
            or {}
        ).get("audio_tokens"),
        f"{service_tier_prefix}reasoning": (
            oai_token_usage.get("completion_tokens_details")
            or oai_token_usage.get("output_tokens_details")
            or {}
        ).get("reasoning_tokens"),
    }
    if recognized_service_tier:
        # Avoid counting cache read and reasoning tokens towards the
        # service tier token count since service tier tokens are already
        # priced differently
        input_token_details[recognized_service_tier] = input_tokens - (
            input_token_details.get(f"{service_tier_prefix}cache_read") or 0
        )
        output_token_details[recognized_service_tier] = output_tokens - (
            output_token_details.get(f"{service_tier_prefix}reasoning") or 0
        )
    return UsageMetadata(
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        total_tokens=total_tokens,
        input_token_details=InputTokenDetails(
            **{k: v for k, v in input_token_details.items() if v is not None}
        ),
        output_token_details=OutputTokenDetails(
            **{k: v for k, v in output_token_details.items() if v is not None}
        ),
    )
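
# Illustrative sketch, not part of the library: a Chat Completions usage payload
# such as
#
#     {
#         "prompt_tokens": 12,
#         "completion_tokens": 30,
#         "total_tokens": 42,
#         "prompt_tokens_details": {"cached_tokens": 4},
#         "completion_tokens_details": {"reasoning_tokens": 10},
#     }
#
# maps to `UsageMetadata(input_tokens=12, output_tokens=30, total_tokens=42,
# input_token_details={"cache_read": 4}, output_token_details={"reasoning": 10})`.
# With `service_tier="flex"`, the detail keys become `flex_cache_read` /
# `flex_reasoning` and a `flex` bucket counts the remaining tokens.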


def _process_chat_completion(outputs: Any):
    try:
        rdict = outputs.model_dump()
        oai_token_usage = rdict.pop("usage", None)
        rdict["usage_metadata"] = (
            _create_usage_metadata(oai_token_usage, rdict.get("service_tier"))
            if oai_token_usage
            else None
        )
        return rdict
    except BaseException as e:
        logger.debug(f"Error processing chat completion: {e}")
        return {"output": outputs}


def _get_wrapper(
    original_create: Callable,
    name: str,
    reduce_fn: Callable,
    tracing_extra: Optional[TracingExtra] = None,
    invocation_params_fn: Optional[Callable] = None,
    process_outputs: Optional[Callable] = None,
) -> Callable:
    textra = tracing_extra or {}

    @functools.wraps(original_create)
    def create(*args, **kwargs):
        decorator = run_helpers.traceable(
            name=name,
            run_type="llm",
            reduce_fn=reduce_fn if kwargs.get("stream") is True else None,
            process_inputs=_process_inputs,
            _invocation_params_fn=invocation_params_fn,
            process_outputs=process_outputs,
            **textra,
        )
        return decorator(original_create)(*args, **kwargs)

    @functools.wraps(original_create)
    async def acreate(*args, **kwargs):
        decorator = run_helpers.traceable(
            name=name,
            run_type="llm",
            reduce_fn=reduce_fn if kwargs.get("stream") is True else None,
            process_inputs=_process_inputs,
            _invocation_params_fn=invocation_params_fn,
            process_outputs=process_outputs,
            **textra,
        )
        return await decorator(original_create)(*args, **kwargs)

    return acreate if run_helpers.is_async(original_create) else create
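
# Note: `_get_wrapper` returns the async variant only when the wrapped method is
# itself async, and it passes `reduce_fn` to `traceable` only when `stream=True`,
# so non-streaming calls are logged from the final response object rather than
# from accumulated chunks.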


def _get_parse_wrapper(
    original_parse: Callable,
    name: str,
    process_outputs: Callable,
    tracing_extra: Optional[TracingExtra] = None,
    invocation_params_fn: Optional[Callable] = None,
) -> Callable:
    textra = tracing_extra or {}

    @functools.wraps(original_parse)
    def parse(*args, **kwargs):
        decorator = run_helpers.traceable(
            name=name,
            run_type="llm",
            reduce_fn=None,
            process_inputs=_process_inputs,
            _invocation_params_fn=invocation_params_fn,
            process_outputs=process_outputs,
            **textra,
        )
        return decorator(original_parse)(*args, **kwargs)

    @functools.wraps(original_parse)
    async def aparse(*args, **kwargs):
        decorator = run_helpers.traceable(
            name=name,
            run_type="llm",
            reduce_fn=None,
            process_inputs=_process_inputs,
            _invocation_params_fn=invocation_params_fn,
            process_outputs=process_outputs,
            **textra,
        )
        return await decorator(original_parse)(*args, **kwargs)

    return aparse if run_helpers.is_async(original_parse) else parse


def _reduce_response_events(events: list[ResponseStreamEvent]) -> dict:
    for event in events:
        if event.type == "response.completed":
            return _process_responses_api_output(event.response)
    return {}


class TracingExtra(TypedDict, total=False):
    metadata: Optional[Mapping[str, Any]]
    tags: Optional[list[str]]
    client: Optional[ls_client.Client]
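
# Illustrative sketch, not part of the library: `tracing_extra` is an ordinary
# dict matching this TypedDict, e.g.
#
#     wrap_openai(
#         openai.OpenAI(),
#         tracing_extra={"tags": ["prod"], "metadata": {"team": "search"}},
#     )
#
# (the tag and metadata values here are hypothetical examples).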


def wrap_openai(
    client: C,
    *,
    tracing_extra: Optional[TracingExtra] = None,
    chat_name: str = "ChatOpenAI",
    completions_name: str = "OpenAI",
) -> C:
    """Patch the OpenAI client to make it traceable.

    Supports:

    - Chat and Responses APIs
    - Sync and async OpenAI clients
    - `create` and `parse` methods
    - With and without streaming

    Args:
        client: The client to patch.
        tracing_extra: Extra tracing information.
        chat_name: The run name for the chat completions endpoint.
        completions_name: The run name for the completions endpoint.

    Returns:
        The patched client.

    Example:
        ```python
        import openai
        from langsmith import wrappers

        # Use the OpenAI client the same as you normally would.
        client = wrappers.wrap_openai(openai.OpenAI())

        # Chat API:
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {
                "role": "user",
                "content": "What physics breakthroughs do you predict will happen by 2300?",
            },
        ]
        completion = client.chat.completions.create(
            model="gpt-4o-mini", messages=messages
        )
        print(completion.choices[0].message.content)

        # Responses API:
        response = client.responses.create(
            model="gpt-4o-mini",
            input=messages,
        )
        print(response.output_text)
        ```

    !!! warning "Behavior changed in `langsmith` 0.3.16"

        Support for Responses API added.
    """  # noqa: E501
    tracing_extra = tracing_extra or {}
    ls_provider = "openai"
    try:
        from openai import AsyncAzureOpenAI, AzureOpenAI

        if isinstance(client, AzureOpenAI) or isinstance(client, AsyncAzureOpenAI):
            ls_provider = "azure"
            chat_name = "AzureChatOpenAI"
            completions_name = "AzureOpenAI"
    except ImportError:
        pass
    # First wrap the create methods - these handle non-streaming cases
    client.chat.completions.create = _get_wrapper(  # type: ignore[method-assign]
        client.chat.completions.create,
        chat_name,
        _reduce_chat,
        tracing_extra=tracing_extra,
        invocation_params_fn=functools.partial(
            _infer_invocation_params, "chat", ls_provider
        ),
        process_outputs=_process_chat_completion,
    )
    client.completions.create = _get_wrapper(  # type: ignore[method-assign]
        client.completions.create,
        completions_name,
        _reduce_completions,
        tracing_extra=tracing_extra,
        invocation_params_fn=functools.partial(
            _infer_invocation_params, "llm", ls_provider
        ),
    )
    # Wrap beta.chat.completions.parse if it exists
    if (
        hasattr(client, "beta")
        and hasattr(client.beta, "chat")
        and hasattr(client.beta.chat, "completions")
        and hasattr(client.beta.chat.completions, "parse")
    ):
        client.beta.chat.completions.parse = _get_parse_wrapper(  # type: ignore[method-assign]
            client.beta.chat.completions.parse,  # type: ignore
            chat_name,
            _process_chat_completion,
            tracing_extra=tracing_extra,
            invocation_params_fn=functools.partial(
                _infer_invocation_params, "chat", ls_provider
            ),
        )
    # Wrap chat.completions.parse if it exists
    if (
        hasattr(client, "chat")
        and hasattr(client.chat, "completions")
        and hasattr(client.chat.completions, "parse")
    ):
        client.chat.completions.parse = _get_parse_wrapper(  # type: ignore[method-assign]
            client.chat.completions.parse,  # type: ignore
            chat_name,
            _process_chat_completion,
            tracing_extra=tracing_extra,
            invocation_params_fn=functools.partial(
                _infer_invocation_params, "chat", ls_provider
            ),
        )
    # For the responses API: "client.responses.create(**kwargs)"
    if hasattr(client, "responses"):
        if hasattr(client.responses, "create"):
            client.responses.create = _get_wrapper(  # type: ignore[method-assign]
                client.responses.create,
                chat_name,
                _reduce_response_events,
                process_outputs=_process_responses_api_output,
                tracing_extra=tracing_extra,
                invocation_params_fn=functools.partial(
                    _infer_invocation_params, "chat", ls_provider
                ),
            )
        if hasattr(client.responses, "parse"):
            client.responses.parse = _get_parse_wrapper(  # type: ignore[method-assign]
                client.responses.parse,
                chat_name,
                _process_responses_api_output,
                tracing_extra=tracing_extra,
                invocation_params_fn=functools.partial(
                    _infer_invocation_params, "chat", ls_provider
                ),
            )
    return client


def _process_responses_api_output(response: Any) -> dict:
    if response:
        try:
            output = response.model_dump(exclude_none=True, mode="json")
            if usage := output.pop("usage", None):
                output["usage_metadata"] = _create_usage_metadata(
                    usage, output.get("service_tier")
                )
            return output
        except Exception:
            return {"output": response}
    return {}