from __future__ import annotations

import functools
import logging
from collections.abc import AsyncIterator, Mapping, Sequence
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Optional,
    TypeVar,
    Union,
)

from pydantic import TypeAdapter
from typing_extensions import Self, TypedDict

from langsmith import client as ls_client
from langsmith import run_helpers
from langsmith.schemas import InputTokenDetails, UsageMetadata

if TYPE_CHECKING:
    import httpx
    from anthropic import Anthropic, AsyncAnthropic
    from anthropic.types import Completion, Message, MessageStreamEvent

C = TypeVar("C", bound=Union["Anthropic", "AsyncAnthropic", Any])

logger = logging.getLogger(__name__)


@functools.lru_cache
def _get_not_given() -> Optional[tuple[type, ...]]:
    try:
        from anthropic._types import NotGiven, Omit

        return (NotGiven, Omit)
    except ImportError:
        return None


def _strip_not_given(d: dict) -> dict:
    # Copy so the caller's kwargs are never mutated when "system" is folded in
    # below.
    d = dict(d)
    try:
        if not_given := _get_not_given():
            d = {
                k: v
                for k, v in d.items()
                if not any(isinstance(v, t) for t in not_given)
            }
    except Exception as e:
        logger.error(f"Error stripping NotGiven: {e}")
    # Represent the top-level system prompt as a leading system message so the
    # traced inputs have the chat-message shape LangSmith expects.
    if "system" in d:
        d["messages"] = [{"role": "system", "content": d["system"]}] + d.get(
            "messages", []
        )
        d.pop("system")
    return {k: v for k, v in d.items() if v is not None}
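
# Example (illustrative values, assuming the anthropic SDK's NOT_GIVEN
# sentinel): given kwargs such as
#     {"model": "claude-3-5-sonnet-latest", "system": "Be terse.",
#      "messages": [{"role": "user", "content": "Hi"}], "top_k": NOT_GIVEN}
# _strip_not_given drops the NOT_GIVEN sentinel and folds the system prompt
# into the message list, yielding
#     {"model": "claude-3-5-sonnet-latest",
#      "messages": [{"role": "system", "content": "Be terse."},
#                   {"role": "user", "content": "Hi"}]}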


def _infer_ls_params(kwargs: dict) -> dict:
    stripped = _strip_not_given(kwargs)
    stop = stripped.get("stop")
    if stop and isinstance(stop, str):
        stop = [stop]
    # Allowlist of known, non-sensitive invocation parameters; everything else
    # (prompts, tools, credentials, etc.) is deliberately excluded.
    allowed_invocation_keys = {
        "mcp_servers",
        "service_tier",
        "top_k",
        "top_p",
        "stream",
        "thinking",
    }
    invocation_params = {
        k: v for k, v in stripped.items() if k in allowed_invocation_keys
    }
    return {
        "ls_provider": "anthropic",
        "ls_model_type": "chat",
        "ls_model_name": stripped.get("model", None),
        "ls_temperature": stripped.get("temperature", None),
        "ls_max_tokens": stripped.get("max_tokens", None),
        "ls_stop": stop,
        "ls_invocation_params": invocation_params,
    }
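
# Example (illustrative values): for
#     _infer_ls_params({"model": "claude-3-5-sonnet-latest", "max_tokens": 256,
#                       "temperature": 0.2, "stop": "END", "top_p": 0.9})
# the inferred metadata would be
#     {"ls_provider": "anthropic", "ls_model_type": "chat",
#      "ls_model_name": "claude-3-5-sonnet-latest", "ls_temperature": 0.2,
#      "ls_max_tokens": 256, "ls_stop": ["END"],
#      "ls_invocation_params": {"top_p": 0.9}}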


def _accumulate_event(
    *, event: MessageStreamEvent, current_snapshot: Message | None
) -> Message | None:
    try:
        from anthropic.types import ContentBlock
    except ImportError:
        logger.debug("Error importing ContentBlock")
        return current_snapshot

    if current_snapshot is None:
        if event.type == "message_start":
            return event.message

        raise RuntimeError(
            f'Unexpected event order, got {event.type} before "message_start"'
        )

    if event.type == "content_block_start":
        # TODO: check index <-- from anthropic SDK :)
        adapter: TypeAdapter = TypeAdapter(ContentBlock)
        content_block_instance = adapter.validate_python(
            event.content_block.model_dump()
        )
        current_snapshot.content.append(
            content_block_instance,  # type: ignore[attr-defined]
        )
    elif event.type == "content_block_delta":
        content = current_snapshot.content[event.index]
        if content.type == "text" and event.delta.type == "text_delta":
            content.text += event.delta.text
    elif event.type == "message_delta":
        current_snapshot.stop_reason = event.delta.stop_reason
        current_snapshot.stop_sequence = event.delta.stop_sequence
        current_snapshot.usage.output_tokens = event.usage.output_tokens

    return current_snapshot
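
# Expected event order (per the Anthropic streaming protocol, assumed here):
#     message_start -> content_block_start -> content_block_delta* ->
#     content_block_stop -> message_delta -> message_stop
# Each event is folded into the running Message snapshot; anything arriving
# before message_start raises RuntimeError, which _reduce_chat_chunks below
# catches to fall back to returning the raw chunk list.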


def _reduce_chat_chunks(all_chunks: Sequence) -> dict:
    full_message = None
    for chunk in all_chunks:
        try:
            full_message = _accumulate_event(event=chunk, current_snapshot=full_message)
        except RuntimeError as e:
            logger.debug(f"Error accumulating event in Anthropic Wrapper: {e}")
            return {"output": all_chunks}
    if full_message is None:
        return {"output": all_chunks}
    d = full_message.model_dump()
    d["usage_metadata"] = _create_usage_metadata(d.pop("usage", {}))
    d.pop("type", None)
    return {"message": d}


def _create_usage_metadata(anthropic_token_usage: dict) -> UsageMetadata:
    input_tokens = anthropic_token_usage.get("input_tokens") or 0
    output_tokens = anthropic_token_usage.get("output_tokens") or 0
    total_tokens = input_tokens + output_tokens
    # Anthropic reports cache-creation and cache-read tokens separately; both
    # are folded into the single "cache_read" detail here. Use `or 0` so an
    # explicit None in the payload cannot break the addition.
    input_token_details: dict = {
        "cache_read": (anthropic_token_usage.get("cache_creation_input_tokens") or 0)
        + (anthropic_token_usage.get("cache_read_input_tokens") or 0)
    }
    return UsageMetadata(
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        total_tokens=total_tokens,
        input_token_details=InputTokenDetails(
            **{k: v for k, v in input_token_details.items() if v is not None}
        ),
    )
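
# Example (illustrative values): an Anthropic usage payload such as
#     {"input_tokens": 10, "output_tokens": 25,
#      "cache_read_input_tokens": 5, "cache_creation_input_tokens": 0}
# maps to
#     UsageMetadata(input_tokens=10, output_tokens=25, total_tokens=35,
#                   input_token_details=InputTokenDetails(cache_read=5))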


def _reduce_completions(all_chunks: list[Completion]) -> dict:
    all_content = []
    for chunk in all_chunks:
        content = chunk.completion
        if content is not None:
            all_content.append(content)
    content = "".join(all_content)
    if all_chunks:
        d = all_chunks[-1].model_dump()
        d["choices"] = [{"text": content}]
    else:
        d = {"choices": [{"text": content}]}
    return d
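
# Example (illustrative values): streamed Completion chunks with
# completion="Hel" and completion="lo" reduce to the final chunk's dump with
# the joined text, e.g. {"choices": [{"text": "Hello"}], ...}.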


def _process_chat_completion(outputs: Any):
    try:
        rdict = outputs.model_dump()
        anthropic_token_usage = rdict.pop("usage", None)
        rdict["usage_metadata"] = (
            _create_usage_metadata(anthropic_token_usage)
            if anthropic_token_usage
            else None
        )
        rdict.pop("type", None)
        return {"message": rdict}
    except BaseException as e:
        logger.debug(f"Error processing chat completion: {e}")
        return {"output": outputs}


def _get_wrapper(
    original_create: Callable,
    name: str,
    reduce_fn: Callable,
    tracing_extra: TracingExtra,
) -> Callable:
    @functools.wraps(original_create)
    def create(*args, **kwargs):
        stream = kwargs.get("stream")
        decorator = run_helpers.traceable(
            name=name,
            run_type="llm",
            reduce_fn=reduce_fn if stream else None,
            process_inputs=_strip_not_given,
            process_outputs=_process_chat_completion,
            _invocation_params_fn=_infer_ls_params,
            **tracing_extra,
        )
        return decorator(original_create)(*args, **kwargs)

    @functools.wraps(original_create)
    async def acreate(*args, **kwargs):
        stream = kwargs.get("stream")
        decorator = run_helpers.traceable(
            name=name,
            run_type="llm",
            reduce_fn=reduce_fn if stream else None,
            process_inputs=_strip_not_given,
            process_outputs=_process_chat_completion,
            _invocation_params_fn=_infer_ls_params,
            **tracing_extra,
        )
        return await decorator(original_create)(*args, **kwargs)

    return acreate if run_helpers.is_async(original_create) else create


def _get_stream_wrapper(
    original_stream: Callable,
    name: str,
    tracing_extra: TracingExtra,
) -> Callable:
    """Create a wrapper for Anthropic's streaming context manager."""
    import anthropic

    # `messages.stream` returns a (Async)MessageStreamManager rather than a
    # coroutine, so inspect its repr to decide which wrapper class to build.
    is_async = "async" in str(original_stream).lower()

    configured_traceable = run_helpers.traceable(
        name=name,
        reduce_fn=_reduce_chat_chunks,
        run_type="llm",
        process_inputs=_strip_not_given,
        _invocation_params_fn=_infer_ls_params,
        **tracing_extra,
    )
    configured_traceable_text = run_helpers.traceable(
        name=name,
        run_type="llm",
        process_inputs=_strip_not_given,
        process_outputs=_process_chat_completion,
        _invocation_params_fn=_infer_ls_params,
        **tracing_extra,
    )

    if is_async:

        class AsyncMessageStreamWrapper:
            def __init__(
                self,
                wrapped: anthropic.lib.streaming._messages.AsyncMessageStream,
                **kwargs,
            ) -> None:
                self._wrapped = wrapped
                self._kwargs = kwargs
                self._iterator: AsyncIterator[MessageStreamEvent] | None = None

            @property
            def text_stream(self):
                @configured_traceable_text
                async def _text_stream(**_):
                    async for chunk in self._wrapped.text_stream:
                        yield chunk
                    # Attach the fully accumulated message as the run output.
                    run_tree = run_helpers.get_current_run_tree()
                    final_message = await self._wrapped.get_final_message()
                    if run_tree is not None:
                        run_tree.outputs = _process_chat_completion(final_message)

                return _text_stream(**self._kwargs)

            @property
            def response(self) -> httpx.Response:
                return self._wrapped.response

            @property
            def request_id(self) -> str | None:
                return self._wrapped.request_id

            async def __anext__(self) -> MessageStreamEvent:
                # Reuse a single traced iterator so each event does not start
                # a fresh traced run.
                if self._iterator is None:
                    self._iterator = self.__aiter__()
                return await self._iterator.__anext__()

            async def __aiter__(self) -> AsyncIterator[MessageStreamEvent]:
                @configured_traceable
                def traced_iter(**_):
                    return self._wrapped.__aiter__()

                async for chunk in traced_iter(**self._kwargs):
                    yield chunk

            async def __aenter__(self) -> Self:
                await self._wrapped.__aenter__()
                return self

            async def __aexit__(self, *exc) -> None:
                await self._wrapped.__aexit__(*exc)

            async def close(self) -> None:
                await self._wrapped.close()

            async def get_final_message(self) -> Message:
                return await self._wrapped.get_final_message()

            async def get_final_text(self) -> str:
                return await self._wrapped.get_final_text()

            async def until_done(self) -> None:
                await self._wrapped.until_done()

            @property
            def current_message_snapshot(self) -> Message:
                return self._wrapped.current_message_snapshot

        class AsyncMessagesStreamManagerWrapper:
            def __init__(self, **kwargs):
                self._kwargs = kwargs

            async def __aenter__(self):
                self._manager = original_stream(**self._kwargs)
                stream = await self._manager.__aenter__()
                return AsyncMessageStreamWrapper(stream, **self._kwargs)

            async def __aexit__(self, *exc):
                await self._manager.__aexit__(*exc)

        return AsyncMessagesStreamManagerWrapper
    else:

        class MessageStreamWrapper:
            def __init__(
                self,
                wrapped: anthropic.lib.streaming._messages.MessageStream,
                **kwargs,
            ) -> None:
                self._wrapped = wrapped
                self._kwargs = kwargs
                self._iterator = None

            @property
            def response(self) -> Any:
                return self._wrapped.response

            @property
            def request_id(self) -> str | None:
                return self._wrapped.request_id  # type: ignore[no-any-return]

            @property
            def text_stream(self):
                @configured_traceable_text
                def _text_stream(**_):
                    yield from self._wrapped.text_stream
                    # Attach the fully accumulated message as the run output.
                    run_tree = run_helpers.get_current_run_tree()
                    final_message = self._wrapped.get_final_message()
                    if run_tree is not None:
                        run_tree.outputs = _process_chat_completion(final_message)

                return _text_stream(**self._kwargs)

            def __next__(self) -> MessageStreamEvent:
                # Reuse a single traced iterator so each event does not start
                # a fresh traced run.
                if self._iterator is None:
                    self._iterator = self.__iter__()
                return self._iterator.__next__()

            def __iter__(self):
                @configured_traceable
                def traced_iter(**_):
                    return self._wrapped.__iter__()

                return traced_iter(**self._kwargs)

            def __enter__(self) -> Self:
                self._wrapped.__enter__()
                return self

            def __exit__(self, *exc) -> None:
                self._wrapped.__exit__(*exc)

            def close(self) -> None:
                self._wrapped.close()

            def get_final_message(self) -> Message:
                return self._wrapped.get_final_message()

            def get_final_text(self) -> str:
                return self._wrapped.get_final_text()

            def until_done(self) -> None:
                return self._wrapped.until_done()

            @property
            def current_message_snapshot(self) -> Message:
                return self._wrapped.current_message_snapshot

        class MessagesStreamManagerWrapper:
            def __init__(self, **kwargs):
                self._kwargs = kwargs

            def __enter__(self):
                self._manager = original_stream(**self._kwargs)
                return MessageStreamWrapper(self._manager.__enter__(), **self._kwargs)

            def __exit__(self, *exc):
                self._manager.__exit__(*exc)

        return MessagesStreamManagerWrapper


class TracingExtra(TypedDict, total=False):
    metadata: Optional[Mapping[str, Any]]
    tags: Optional[list[str]]
    client: Optional[ls_client.Client]


def wrap_anthropic(client: C, *, tracing_extra: Optional[TracingExtra] = None) -> C:
    """Patch the Anthropic client to make it traceable.

    Args:
        client: The client to patch.
        tracing_extra: Extra tracing information.

    Returns:
        The patched client.

    Example:
        ```python
        import anthropic
        from langsmith import wrappers

        client = wrappers.wrap_anthropic(anthropic.Anthropic())

        # Use the Anthropic client the same as you normally would:
        system = "You are a helpful assistant."
        messages = [
            {
                "role": "user",
                "content": "What physics breakthroughs do you predict will happen by 2300?",
            }
        ]
        completion = client.messages.create(
            model="claude-3-5-sonnet-latest",
            messages=messages,
            max_tokens=1000,
            system=system,
        )
        print(completion.content)

        # You can also use the streaming context manager:
        with client.messages.stream(
            model="claude-3-5-sonnet-latest",
            messages=messages,
            max_tokens=1000,
            system=system,
        ) as stream:
            for text in stream.text_stream:
                print(text, end="", flush=True)
            message = stream.get_final_message()
        ```
    """  # noqa: E501
    tracing_extra = tracing_extra or {}
    client.messages.create = _get_wrapper(  # type: ignore[method-assign]
        client.messages.create,
        "ChatAnthropic",
        _reduce_chat_chunks,
        tracing_extra,
    )
    client.messages.stream = _get_stream_wrapper(  # type: ignore[method-assign]
        client.messages.stream,
        "ChatAnthropic",
        tracing_extra,
    )
    client.completions.create = _get_wrapper(  # type: ignore[method-assign]
        client.completions.create,
        "Anthropic",
        _reduce_completions,
        tracing_extra,
    )
    if (
        hasattr(client, "beta")
        and hasattr(client.beta, "messages")
        and hasattr(client.beta.messages, "create")
    ):
        client.beta.messages.create = _get_wrapper(  # type: ignore[method-assign]
            client.beta.messages.create,  # type: ignore
            "ChatAnthropic",
            _reduce_chat_chunks,
            tracing_extra,
        )
    return client
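
# Minimal usage sketch for the async client (assumes ANTHROPIC_API_KEY and the
# LangSmith tracing environment variables are configured):
#
#     import asyncio
#
#     import anthropic
#     from langsmith import wrappers
#
#     aclient = wrappers.wrap_anthropic(anthropic.AsyncAnthropic())
#
#     async def main() -> None:
#         message = await aclient.messages.create(
#             model="claude-3-5-sonnet-latest",
#             max_tokens=256,
#             messages=[{"role": "user", "content": "Hello"}],
#         )
#         print(message.content)
#
#     asyncio.run(main())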