# _anthropic.py

from __future__ import annotations

import functools
import logging
from collections.abc import AsyncIterator, Mapping, Sequence
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Optional,
    TypeVar,
    Union,
)

from pydantic import TypeAdapter
from typing_extensions import Self, TypedDict

from langsmith import client as ls_client
from langsmith import run_helpers
from langsmith.schemas import InputTokenDetails, UsageMetadata

if TYPE_CHECKING:
    import httpx
    from anthropic import Anthropic, AsyncAnthropic
    from anthropic.types import Completion, Message, MessageStreamEvent

C = TypeVar("C", bound=Union["Anthropic", "AsyncAnthropic", Any])

logger = logging.getLogger(__name__)


@functools.lru_cache
def _get_not_given() -> Optional[tuple[type, ...]]:
    """Return Anthropic's sentinel "not given" types, or None if unavailable."""
    try:
        from anthropic._types import NotGiven, Omit

        return (NotGiven, Omit)
    except ImportError:
        return None


def _strip_not_given(d: dict) -> dict:
    """Drop sentinel and None values and fold the system prompt into messages."""
    try:
        if not_given := _get_not_given():
            d = {
                k: v
                for k, v in d.items()
                if not any(isinstance(v, t) for t in not_given)
            }
    except Exception as e:
        logger.error(f"Error stripping NotGiven: {e}")
    if "system" in d:
        d["messages"] = [{"role": "system", "content": d["system"]}] + d.get(
            "messages", []
        )
        d.pop("system")
    return {k: v for k, v in d.items() if v is not None}
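

# A minimal illustration (comment only, not executed): the Anthropic-style
# "system" kwarg is folded into the message list so traced inputs read as a
# single chat transcript:
#
#     _strip_not_given(
#         {"system": "Be terse.", "messages": [{"role": "user", "content": "hi"}]}
#     )
#     # -> {"messages": [{"role": "system", "content": "Be terse."},
#     #                  {"role": "user", "content": "hi"}]}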


def _infer_ls_params(kwargs: dict):
    """Infer LangSmith run metadata from the Anthropic call kwargs."""
    stripped = _strip_not_given(kwargs)
    stop = stripped.get("stop")
    if stop and isinstance(stop, str):
        stop = [stop]
    # Allowlist of safe invocation parameters to include.
    # Only include known, non-sensitive parameters.
    allowed_invocation_keys = {
        "mcp_servers",
        "service_tier",
        "top_k",
        "top_p",
        "stream",
        "thinking",
    }
    invocation_params = {
        k: v for k, v in stripped.items() if k in allowed_invocation_keys
    }
    return {
        "ls_provider": "anthropic",
        "ls_model_type": "chat",
        "ls_model_name": stripped.get("model", None),
        "ls_temperature": stripped.get("temperature", None),
        "ls_max_tokens": stripped.get("max_tokens", None),
        "ls_stop": stop,
        "ls_invocation_params": invocation_params,
    }
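

# A minimal illustration (comment only): for a call such as
#
#     _infer_ls_params({"model": "claude-3-5-sonnet-latest", "max_tokens": 1000,
#                       "temperature": 0.2, "top_p": 0.9, "stop": "END"})
#
# the result would be roughly
#
#     {"ls_provider": "anthropic", "ls_model_type": "chat",
#      "ls_model_name": "claude-3-5-sonnet-latest", "ls_temperature": 0.2,
#      "ls_max_tokens": 1000, "ls_stop": ["END"],
#      "ls_invocation_params": {"top_p": 0.9}}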


def _accumulate_event(
    *, event: MessageStreamEvent, current_snapshot: Message | None
) -> Message | None:
    """Fold one stream event into the running Message snapshot."""
    try:
        from anthropic.types import ContentBlock
    except ImportError:
        logger.debug("Error importing ContentBlock")
        return current_snapshot
    if current_snapshot is None:
        if event.type == "message_start":
            return event.message
        raise RuntimeError(
            f'Unexpected event order, got {event.type} before "message_start"'
        )
    if event.type == "content_block_start":
        # TODO: check index <-- from anthropic SDK :)
        adapter: TypeAdapter = TypeAdapter(ContentBlock)
        content_block_instance = adapter.validate_python(
            event.content_block.model_dump()
        )
        current_snapshot.content.append(
            content_block_instance,  # type: ignore[attr-defined]
        )
    elif event.type == "content_block_delta":
        content = current_snapshot.content[event.index]
        if content.type == "text" and event.delta.type == "text_delta":
            content.text += event.delta.text
    elif event.type == "message_delta":
        current_snapshot.stop_reason = event.delta.stop_reason
        current_snapshot.stop_sequence = event.delta.stop_sequence
        current_snapshot.usage.output_tokens = event.usage.output_tokens
    return current_snapshot


def _reduce_chat_chunks(all_chunks: Sequence) -> dict:
    """Collapse streamed events into a single message payload for the trace."""
    full_message = None
    for chunk in all_chunks:
        try:
            full_message = _accumulate_event(event=chunk, current_snapshot=full_message)
        except RuntimeError as e:
            logger.debug(f"Error accumulating event in Anthropic Wrapper: {e}")
            return {"output": all_chunks}
    if full_message is None:
        return {"output": all_chunks}
    d = full_message.model_dump()
    d["usage_metadata"] = _create_usage_metadata(d.pop("usage", {}))
    d.pop("type", None)
    return {"message": d}
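

# Event flow sketch (comment only): a well-formed stream looks like
#     message_start -> content_block_start -> content_block_delta* -> message_delta
# and accumulates into one Message. On out-of-order events the raw chunk list
# is returned unreduced so the trace still records the output.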


def _create_usage_metadata(anthropic_token_usage: dict) -> UsageMetadata:
    """Convert Anthropic token usage into LangSmith UsageMetadata."""
    input_tokens = anthropic_token_usage.get("input_tokens") or 0
    output_tokens = anthropic_token_usage.get("output_tokens") or 0
    total_tokens = input_tokens + output_tokens
    input_token_details: dict = {
        # Treat missing or None cache counters as 0 before summing.
        "cache_read": (anthropic_token_usage.get("cache_creation_input_tokens") or 0)
        + (anthropic_token_usage.get("cache_read_input_tokens") or 0)
    }
    return UsageMetadata(
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        total_tokens=total_tokens,
        input_token_details=InputTokenDetails(
            **{k: v for k, v in input_token_details.items() if v is not None}
        ),
    )
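

# Worked example (comment only): given usage
#     {"input_tokens": 10, "output_tokens": 5,
#      "cache_creation_input_tokens": 3, "cache_read_input_tokens": 2}
# this yields input_tokens=10, output_tokens=5, total_tokens=10 + 5 = 15, and
# both cache counters are summed into a single cache_read detail of 3 + 2 = 5.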


def _reduce_completions(all_chunks: list[Completion]) -> dict:
    """Join streamed text completions into one output payload."""
    all_content = []
    for chunk in all_chunks:
        content = chunk.completion
        if content is not None:
            all_content.append(content)
    content = "".join(all_content)
    if all_chunks:
        d = all_chunks[-1].model_dump()
        d["choices"] = [{"text": content}]
    else:
        d = {"choices": [{"text": content}]}
    return d
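

# A minimal illustration (comment only): chunks whose .completion values are
# "Hel", "lo", "!" reduce to {"choices": [{"text": "Hello!"}], ...}, with the
# remaining fields taken from the final chunk's model_dump().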


def _process_chat_completion(outputs: Any):
    """Normalize a chat completion into a trace-friendly dict with usage metadata."""
    try:
        rdict = outputs.model_dump()
        anthropic_token_usage = rdict.pop("usage", None)
        rdict["usage_metadata"] = (
            _create_usage_metadata(anthropic_token_usage)
            if anthropic_token_usage
            else None
        )
        rdict.pop("type", None)
        return {"message": rdict}
    except BaseException as e:
        logger.debug(f"Error processing chat completion: {e}")
        return {"output": outputs}


def _get_wrapper(
    original_create: Callable,
    name: str,
    reduce_fn: Callable,
    tracing_extra: TracingExtra,
) -> Callable:
    """Wrap a sync or async create method with LangSmith tracing."""

    @functools.wraps(original_create)
    def create(*args, **kwargs):
        stream = kwargs.get("stream")
        decorator = run_helpers.traceable(
            name=name,
            run_type="llm",
            reduce_fn=reduce_fn if stream else None,
            process_inputs=_strip_not_given,
            process_outputs=_process_chat_completion,
            _invocation_params_fn=_infer_ls_params,
            **tracing_extra,
        )
        result = decorator(original_create)(*args, **kwargs)
        return result

    @functools.wraps(original_create)
    async def acreate(*args, **kwargs):
        stream = kwargs.get("stream")
        decorator = run_helpers.traceable(
            name=name,
            run_type="llm",
            reduce_fn=reduce_fn if stream else None,
            process_inputs=_strip_not_given,
            process_outputs=_process_chat_completion,
            _invocation_params_fn=_infer_ls_params,
            **tracing_extra,
        )
        result = await decorator(original_create)(*args, **kwargs)
        return result

    return acreate if run_helpers.is_async(original_create) else create


def _get_stream_wrapper(
    original_stream: Callable,
    name: str,
    tracing_extra: TracingExtra,
) -> Callable:
    """Create a wrapper for Anthropic's streaming context manager."""
    import anthropic

    is_async = "async" in str(original_stream).lower()
    configured_traceable = run_helpers.traceable(
        name=name,
        reduce_fn=_reduce_chat_chunks,
        run_type="llm",
        process_inputs=_strip_not_given,
        _invocation_params_fn=_infer_ls_params,
        **tracing_extra,
    )
    configured_traceable_text = run_helpers.traceable(
        name=name,
        run_type="llm",
        process_inputs=_strip_not_given,
        process_outputs=_process_chat_completion,
        _invocation_params_fn=_infer_ls_params,
        **tracing_extra,
    )
    if is_async:

        class AsyncMessageStreamWrapper:
            def __init__(
                self,
                wrapped: anthropic.lib.streaming._messages.AsyncMessageStream,
                **kwargs,
            ) -> None:
                self._wrapped = wrapped
                self._kwargs = kwargs

            @property
            def text_stream(self):
                @configured_traceable_text
                async def _text_stream(**_):
                    async for chunk in self._wrapped.text_stream:
                        yield chunk
                    run_tree = run_helpers.get_current_run_tree()
                    final_message = await self._wrapped.get_final_message()
                    run_tree.outputs = _process_chat_completion(final_message)

                return _text_stream(**self._kwargs)

            @property
            def response(self) -> httpx.Response:
                return self._wrapped.response

            @property
            def request_id(self) -> str | None:
                return self._wrapped.request_id

            async def __anext__(self) -> MessageStreamEvent:
                aiter = self.__aiter__()
                return await aiter.__anext__()

            async def __aiter__(self) -> AsyncIterator[MessageStreamEvent]:
                @configured_traceable
                def traced_iter(**_):
                    return self._wrapped.__aiter__()

                async for chunk in traced_iter(**self._kwargs):
                    yield chunk

            async def __aenter__(self) -> Self:
                await self._wrapped.__aenter__()
                return self

            async def __aexit__(self, *exc) -> None:
                await self._wrapped.__aexit__(*exc)

            async def close(self) -> None:
                await self._wrapped.close()

            async def get_final_message(self) -> Message:
                return await self._wrapped.get_final_message()

            async def get_final_text(self) -> str:
                return await self._wrapped.get_final_text()

            async def until_done(self) -> None:
                await self._wrapped.until_done()

            @property
            def current_message_snapshot(self) -> Message:
                return self._wrapped.current_message_snapshot

        class AsyncMessagesStreamManagerWrapper:
            def __init__(self, **kwargs):
                self._kwargs = kwargs

            async def __aenter__(self):
                self._manager = original_stream(**self._kwargs)
                stream = await self._manager.__aenter__()
                return AsyncMessageStreamWrapper(stream, **self._kwargs)

            async def __aexit__(self, *exc):
                await self._manager.__aexit__(*exc)

        return AsyncMessagesStreamManagerWrapper
    else:

        class MessageStreamWrapper:
            def __init__(
                self,
                wrapped: anthropic.lib.streaming._messages.MessageStream,
                **kwargs,
            ) -> None:
                self._wrapped = wrapped
                self._kwargs = kwargs

            @property
            def response(self) -> Any:
                return self._wrapped.response

            @property
            def request_id(self) -> str | None:
                return self._wrapped.request_id  # type: ignore[no-any-return]

            @property
            def text_stream(self):
                @configured_traceable_text
                def _text_stream(**_):
                    yield from self._wrapped.text_stream
                    run_tree = run_helpers.get_current_run_tree()
                    final_message = self._wrapped.get_final_message()
                    run_tree.outputs = _process_chat_completion(final_message)

                return _text_stream(**self._kwargs)

            def __next__(self) -> MessageStreamEvent:
                return self.__iter__().__next__()

            def __iter__(self):
                @configured_traceable
                def traced_iter(**_):
                    return self._wrapped.__iter__()

                return traced_iter(**self._kwargs)

            def __enter__(self) -> Self:
                self._wrapped.__enter__()
                return self

            def __exit__(self, *exc) -> None:
                self._wrapped.__exit__(*exc)

            def close(self) -> None:
                self._wrapped.close()

            def get_final_message(self) -> Message:
                return self._wrapped.get_final_message()

            def get_final_text(self) -> str:
                return self._wrapped.get_final_text()

            def until_done(self) -> None:
                return self._wrapped.until_done()

            @property
            def current_message_snapshot(self) -> Message:
                return self._wrapped.current_message_snapshot

        class MessagesStreamManagerWrapper:
            def __init__(self, **kwargs):
                self._kwargs = kwargs

            def __enter__(self):
                self._manager = original_stream(**self._kwargs)
                return MessageStreamWrapper(self._manager.__enter__(), **self._kwargs)

            def __exit__(self, *exc):
                self._manager.__exit__(*exc)

        return MessagesStreamManagerWrapper


class TracingExtra(TypedDict, total=False):
    metadata: Optional[Mapping[str, Any]]
    tags: Optional[list[str]]
    client: Optional[ls_client.Client]


def wrap_anthropic(client: C, *, tracing_extra: Optional[TracingExtra] = None) -> C:
    """Patch the Anthropic client to make it traceable.

    Args:
        client: The client to patch.
        tracing_extra: Extra tracing information.

    Returns:
        The patched client.

    Example:
        ```python
        import anthropic
        from langsmith import wrappers

        client = wrappers.wrap_anthropic(anthropic.Anthropic())

        # Use the Anthropic client the same as you normally would:
        system = "You are a helpful assistant."
        messages = [
            {
                "role": "user",
                "content": "What physics breakthroughs do you predict will happen by 2300?",
            }
        ]
        completion = client.messages.create(
            model="claude-3-5-sonnet-latest",
            messages=messages,
            max_tokens=1000,
            system=system,
        )
        print(completion.content)

        # You can also use the streaming context manager:
        with client.messages.stream(
            model="claude-3-5-sonnet-latest",
            messages=messages,
            max_tokens=1000,
            system=system,
        ) as stream:
            for text in stream.text_stream:
                print(text, end="", flush=True)
            message = stream.get_final_message()
        ```
    """  # noqa: E501
    tracing_extra = tracing_extra or {}
    client.messages.create = _get_wrapper(  # type: ignore[method-assign]
        client.messages.create,
        "ChatAnthropic",
        _reduce_chat_chunks,
        tracing_extra,
    )
    client.messages.stream = _get_stream_wrapper(  # type: ignore[method-assign]
        client.messages.stream,
        "ChatAnthropic",
        tracing_extra,
    )
    client.completions.create = _get_wrapper(  # type: ignore[method-assign]
        client.completions.create,
        "Anthropic",
        _reduce_completions,
        tracing_extra,
    )
    if (
        hasattr(client, "beta")
        and hasattr(client.beta, "messages")
        and hasattr(client.beta.messages, "create")
    ):
        client.beta.messages.create = _get_wrapper(  # type: ignore[method-assign]
            client.beta.messages.create,  # type: ignore
            "ChatAnthropic",
            _reduce_chat_chunks,
            tracing_extra,
        )
    return client
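

if __name__ == "__main__":
    # Minimal async usage sketch (illustrative only): assumes the `anthropic`
    # package is installed, ANTHROPIC_API_KEY is set, and the model name is a
    # placeholder. Mirrors the sync docstring example for AsyncAnthropic,
    # exercising the traced streaming context manager.
    import asyncio

    import anthropic

    async def _demo() -> None:
        aclient = wrap_anthropic(anthropic.AsyncAnthropic())
        async with aclient.messages.stream(
            model="claude-3-5-sonnet-latest",
            max_tokens=256,
            messages=[{"role": "user", "content": "Say hello."}],
        ) as stream:
            async for text in stream.text_stream:
                print(text, end="", flush=True)
        print()

    asyncio.run(_demo())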