_gemini.py

from __future__ import annotations

import base64
import functools
import json
import logging
from collections.abc import Mapping
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Optional,
    TypeVar,
    Union,
)

from typing_extensions import TypedDict

from langsmith import client as ls_client
from langsmith import run_helpers
from langsmith._internal._beta_decorator import warn_beta
from langsmith.schemas import InputTokenDetails, OutputTokenDetails, UsageMetadata

if TYPE_CHECKING:
    from google import genai  # type: ignore[import-untyped, attr-defined]

C = TypeVar("C", bound=Union["genai.Client", Any])

logger = logging.getLogger(__name__)


def _strip_none(d: dict) -> dict:
    """Remove `None` values from dictionary."""
    return {k: v for k, v in d.items() if v is not None}


def _convert_config_for_tracing(kwargs: dict) -> None:
    """Convert `GenerateContentConfig` to `dict` for LangSmith compatibility."""
    if "config" in kwargs and not isinstance(kwargs["config"], dict):
        kwargs["config"] = vars(kwargs["config"])
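# Illustrative behavior (hypothetical values, not taken from a real trace): a call like
#     kwargs = {"config": types.GenerateContentConfig(temperature=0.2)}
#     _convert_config_for_tracing(kwargs)
# mutates kwargs in place so that kwargs["config"] becomes the object's attribute
# dict (via vars()), e.g. {"temperature": 0.2, ...}, which LangSmith can serialize
# when recording the traced inputs.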


def _process_gemini_inputs(inputs: dict) -> dict:
    r"""Process Gemini inputs to normalize them for LangSmith tracing.

    Example:
        ```txt
        {"contents": "Hello", "model": "gemini-pro"}
        → {"messages": [{"role": "user", "content": "Hello"}], "model": "gemini-pro"}

        {"contents": [{"role": "user", "parts": [{"text": "What is AI?"}]}], "model": "gemini-pro"}
        → {"messages": [{"role": "user", "content": "What is AI?"}], "model": "gemini-pro"}
        ```
    """  # noqa: E501
    # If contents is not present or empty, return as-is
    contents = inputs.get("contents")
    if not contents:
        return inputs

    # Handle string input (simple case)
    if isinstance(contents, str):
        return {
            "messages": [{"role": "user", "content": contents}],
            "model": inputs.get("model"),
            **({k: v for k, v in inputs.items() if k not in ("contents", "model")}),
        }

    # Handle list of content objects (multimodal case)
    if isinstance(contents, list):
        # Check if it's a simple list of strings
        if all(isinstance(item, str) for item in contents):
            # Each string becomes a separate user message (matches Gemini's behavior)
            return {
                "messages": [{"role": "user", "content": item} for item in contents],
                "model": inputs.get("model"),
                **({k: v for k, v in inputs.items() if k not in ("contents", "model")}),
            }

        # Handle complex multimodal case
        messages = []
        for content in contents:
            if isinstance(content, dict):
                role = content.get("role", "user")
                parts = content.get("parts", [])

                # Extract text and other parts
                text_parts = []
                content_parts = []
                for part in parts:
                    if isinstance(part, dict):
                        # Handle text parts
                        if "text" in part and part["text"]:
                            text_parts.append(part["text"])
                            content_parts.append({"type": "text", "text": part["text"]})
                        # Handle inline data (images)
                        elif "inline_data" in part:
                            inline_data = part["inline_data"]
                            mime_type = inline_data.get("mime_type", "image/jpeg")
                            data = inline_data.get("data", b"")
                            # Convert bytes to base64 string if needed
                            if isinstance(data, bytes):
                                data_b64 = base64.b64encode(data).decode("utf-8")
                            else:
                                data_b64 = data  # Already a string
                            content_parts.append(
                                {
                                    "type": "image_url",
                                    "image_url": {
                                        "url": f"data:{mime_type};base64,{data_b64}",
                                        "detail": "high",
                                    },
                                }
                            )
                        # Handle function responses
                        elif "functionResponse" in part:
                            function_response = part["functionResponse"]
                            content_parts.append(
                                {
                                    "type": "function_response",
                                    "function_response": {
                                        "name": function_response.get("name"),
                                        "response": function_response.get(
                                            "response", {}
                                        ),
                                    },
                                }
                            )
                        # Handle function calls (for conversation history)
                        elif "function_call" in part or "functionCall" in part:
                            function_call = part.get("function_call") or part.get(
                                "functionCall"
                            )
                            if function_call is not None:
                                # Normalize to dict (FunctionCall is a Pydantic model)
                                if not isinstance(function_call, dict):
                                    function_call = function_call.to_dict()
                                content_parts.append(
                                    {
                                        "type": "function_call",
                                        "function_call": {
                                            "id": function_call.get("id"),
                                            "name": function_call.get("name"),
                                            "arguments": function_call.get("args", {}),
                                        },
                                    }
                                )
                    elif isinstance(part, str):
                        # Handle simple string parts
                        text_parts.append(part)
                        content_parts.append({"type": "text", "text": part})

                # If only text parts, use simple string format
                if content_parts and all(
                    p.get("type") == "text" for p in content_parts
                ):
                    message_content: Union[str, list[dict[str, Any]]] = "\n".join(
                        text_parts
                    )
                else:
                    message_content = content_parts if content_parts else ""

                messages.append({"role": role, "content": message_content})

        return {
            "messages": messages,
            "model": inputs.get("model"),
            **({k: v for k, v in inputs.items() if k not in ("contents", "model")}),
        }

    # Fallback: return original inputs
    return inputs
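# Illustrative normalization of a multimodal request (hypothetical values): an input
# such as
#     {"model": "gemini-2.5-flash",
#      "contents": [{"role": "user",
#                    "parts": [{"text": "Describe this"},
#                              {"inline_data": {"mime_type": "image/png", "data": b"..."}}]}]}
# is traced roughly as
#     {"model": "gemini-2.5-flash",
#      "messages": [{"role": "user",
#                    "content": [{"type": "text", "text": "Describe this"},
#                                {"type": "image_url",
#                                 "image_url": {"url": "data:image/png;base64,...",
#                                               "detail": "high"}}]}]}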


def _infer_invocation_params(kwargs: dict) -> dict:
    """Extract invocation parameters for tracing."""
    stripped = _strip_none(kwargs)
    config = stripped.get("config", {})

    # Handle both dict config and GenerateContentConfig object
    if hasattr(config, "temperature"):
        temperature = config.temperature
        max_tokens = getattr(config, "max_output_tokens", None)
        stop = getattr(config, "stop_sequences", None)
    else:
        temperature = config.get("temperature")
        max_tokens = config.get("max_output_tokens")
        stop = config.get("stop_sequences")

    return {
        "ls_provider": "google",
        "ls_model_type": "chat",
        "ls_model_name": stripped.get("model"),
        "ls_temperature": temperature,
        "ls_max_tokens": max_tokens,
        "ls_stop": stop,
    }
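# Illustrative output (hypothetical values): for
#     kwargs = {"model": "gemini-2.5-flash", "config": {"temperature": 0.2}}
# this yields
#     {"ls_provider": "google", "ls_model_type": "chat",
#      "ls_model_name": "gemini-2.5-flash", "ls_temperature": 0.2,
#      "ls_max_tokens": None, "ls_stop": None}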


def _create_usage_metadata(gemini_usage_metadata: dict) -> UsageMetadata:
    """Convert Gemini usage metadata to LangSmith format."""
    prompt_token_count = gemini_usage_metadata.get("prompt_token_count") or 0
    candidates_token_count = gemini_usage_metadata.get("candidates_token_count") or 0
    cached_content_token_count = (
        gemini_usage_metadata.get("cached_content_token_count") or 0
    )
    thoughts_token_count = gemini_usage_metadata.get("thoughts_token_count") or 0
    total_token_count = (
        gemini_usage_metadata.get("total_token_count")
        or prompt_token_count + candidates_token_count
    )

    input_token_details: dict = {}
    if cached_content_token_count:
        input_token_details["cache_read"] = cached_content_token_count

    output_token_details: dict = {}
    if thoughts_token_count:
        output_token_details["reasoning"] = thoughts_token_count

    return UsageMetadata(
        input_tokens=prompt_token_count,
        output_tokens=candidates_token_count,
        total_tokens=total_token_count,
        input_token_details=InputTokenDetails(
            **{k: v for k, v in input_token_details.items() if v is not None}
        ),
        output_token_details=OutputTokenDetails(
            **{k: v for k, v in output_token_details.items() if v is not None}
        ),
    )
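# Illustrative mapping (hypothetical counts): a Gemini usage payload such as
#     {"prompt_token_count": 10, "candidates_token_count": 5,
#      "thoughts_token_count": 3, "total_token_count": 18}
# becomes roughly
#     UsageMetadata(input_tokens=10, output_tokens=5, total_tokens=18,
#                   input_token_details=InputTokenDetails(),
#                   output_token_details=OutputTokenDetails(reasoning=3))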


def _process_generate_content_response(response: Any) -> dict:
    """Process Gemini response for tracing."""
    try:
        # Convert response to dictionary
        if hasattr(response, "to_dict"):
            rdict = response.to_dict()
        elif hasattr(response, "model_dump"):
            rdict = response.model_dump()
        else:
            rdict = {"text": getattr(response, "text", str(response))}

        # Extract content from candidates if available
        content_result = ""
        content_parts = []
        finish_reason: Optional[str] = None
        if "candidates" in rdict and rdict["candidates"]:
            candidate = rdict["candidates"][0]
            if "content" in candidate:
                content = candidate["content"]
                if "parts" in content and content["parts"]:
                    for part in content["parts"]:
                        # Handle text parts
                        if "text" in part and part["text"]:
                            content_result += part["text"]
                            content_parts.append({"type": "text", "text": part["text"]})
                        # Handle inline data (images) in response
                        elif "inline_data" in part and part["inline_data"] is not None:
                            inline_data = part["inline_data"]
                            mime_type = inline_data.get("mime_type", "image/jpeg")
                            data = inline_data.get("data", b"")
                            # Convert bytes to base64 string if needed
                            if isinstance(data, bytes):
                                data_b64 = base64.b64encode(data).decode("utf-8")
                            else:
                                data_b64 = data  # Already a string
                            content_parts.append(
                                {
                                    "type": "image_url",
                                    "image_url": {
                                        "url": f"data:{mime_type};base64,{data_b64}",
                                        "detail": "high",
                                    },
                                }
                            )
                        # Handle function calls in response
                        elif "function_call" in part or "functionCall" in part:
                            function_call = part.get("function_call") or part.get(
                                "functionCall"
                            )
                            if function_call is not None:
                                # Normalize to dict (FunctionCall is a Pydantic model)
                                if not isinstance(function_call, dict):
                                    function_call = function_call.to_dict()
                                content_parts.append(
                                    {
                                        "type": "function_call",
                                        "function_call": {
                                            "id": function_call.get("id"),
                                            "name": function_call.get("name"),
                                            "arguments": function_call.get("args", {}),
                                        },
                                    }
                                )
            if "finish_reason" in candidate and candidate["finish_reason"]:
                finish_reason = candidate["finish_reason"]
        elif "text" in rdict:
            content_result = rdict["text"]
            content_parts.append({"type": "text", "text": content_result})

        # Build chat-like response format - use OpenAI-compatible format for tool calls
        tool_calls = [p for p in content_parts if p.get("type") == "function_call"]
        if tool_calls:
            # OpenAI-compatible format for LangSmith UI
            result = {
                "content": content_result or None,
                "role": "assistant",
                "finish_reason": finish_reason,
                "tool_calls": [
                    {
                        "id": tc["function_call"].get("id") or f"call_{i}",
                        "type": "function",
                        "index": i,
                        "function": {
                            "name": tc["function_call"]["name"],
                            "arguments": json.dumps(tc["function_call"]["arguments"]),
                        },
                    }
                    for i, tc in enumerate(tool_calls)
                ],
            }
        elif len(content_parts) > 1 or (
            content_parts and content_parts[0]["type"] != "text"
        ):
            # Use structured format for mixed non-tool content
            result = {
                "content": content_parts,
                "role": "assistant",
                "finish_reason": finish_reason,
            }
        else:
            # Use simple string format for text-only responses
            result = {
                "content": content_result,
                "role": "assistant",
                "finish_reason": finish_reason,
            }

        # Extract and convert usage metadata
        usage_metadata = rdict.get("usage_metadata")
        usage_dict: UsageMetadata = UsageMetadata(
            input_tokens=0, output_tokens=0, total_tokens=0
        )
        if usage_metadata:
            usage_dict = _create_usage_metadata(usage_metadata)
            # Add usage_metadata to both run.extra AND outputs
            current_run = run_helpers.get_current_run_tree()
            if current_run:
                try:
                    meta = current_run.extra.setdefault("metadata", {}).setdefault(
                        "usage_metadata", {}
                    )
                    meta.update(usage_dict)
                    current_run.patch()
                except Exception as e:
                    logger.warning(f"Failed to update usage metadata: {e}")

        # Return in a format that avoids stringification by LangSmith
        if result.get("tool_calls"):
            # For responses with tool calls, return structured format
            return {
                "content": result["content"],
                "role": "assistant",
                "finish_reason": finish_reason,
                "tool_calls": result["tool_calls"],
                "usage_metadata": usage_dict,
            }
        else:
            # For simple text responses, return minimal structure with usage metadata
            if isinstance(result["content"], str):
                return {
                    "content": result["content"],
                    "role": "assistant",
                    "finish_reason": finish_reason,
                    "usage_metadata": usage_dict,
                }
            else:
                # For multimodal content, return structured format with usage metadata
                return {
                    "content": result["content"],
                    "role": "assistant",
                    "finish_reason": finish_reason,
                    "usage_metadata": usage_dict,
                }
    except Exception as e:
        logger.debug(f"Error processing Gemini response: {e}")
        return {"output": response}
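# Illustrative traced output for a function-calling response (hypothetical values):
#     {"content": None, "role": "assistant", "finish_reason": "STOP",
#      "tool_calls": [{"id": "call_0", "type": "function", "index": 0,
#                      "function": {"name": "schedule_meeting",
#                                   "arguments": '{"date": "2024-01-01"}'}}],
#      "usage_metadata": {...}}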


def _reduce_generate_content_chunks(all_chunks: list) -> dict:
    """Reduce streaming chunks into a single response."""
    if not all_chunks:
        return {
            "content": "",
            "usage_metadata": UsageMetadata(
                input_tokens=0, output_tokens=0, total_tokens=0
            ),
        }

    # Accumulate text from all chunks
    full_text = ""
    last_chunk = None
    for chunk in all_chunks:
        try:
            if hasattr(chunk, "text") and chunk.text:
                full_text += chunk.text
            last_chunk = chunk
        except Exception as e:
            logger.debug(f"Error processing chunk: {e}")

    # Extract usage metadata from the last chunk
    usage_metadata: UsageMetadata = UsageMetadata(
        input_tokens=0, output_tokens=0, total_tokens=0
    )
    if last_chunk:
        try:
            if hasattr(last_chunk, "usage_metadata") and last_chunk.usage_metadata:
                if hasattr(last_chunk.usage_metadata, "to_dict"):
                    usage_dict = last_chunk.usage_metadata.to_dict()
                elif hasattr(last_chunk.usage_metadata, "model_dump"):
                    usage_dict = last_chunk.usage_metadata.model_dump()
                else:
                    usage_dict = {
                        "prompt_token_count": getattr(
                            last_chunk.usage_metadata, "prompt_token_count", 0
                        ),
                        "candidates_token_count": getattr(
                            last_chunk.usage_metadata, "candidates_token_count", 0
                        ),
                        "cached_content_token_count": getattr(
                            last_chunk.usage_metadata, "cached_content_token_count", 0
                        ),
                        "thoughts_token_count": getattr(
                            last_chunk.usage_metadata, "thoughts_token_count", 0
                        ),
                        "total_token_count": getattr(
                            last_chunk.usage_metadata, "total_token_count", 0
                        ),
                    }

                # Add usage_metadata to both run.extra AND outputs
                usage_metadata = _create_usage_metadata(usage_dict)
                current_run = run_helpers.get_current_run_tree()
                if current_run:
                    try:
                        meta = current_run.extra.setdefault("metadata", {}).setdefault(
                            "usage_metadata", {}
                        )
                        meta.update(usage_metadata)
                        current_run.patch()
                    except Exception as e:
                        logger.warning(f"Failed to update usage metadata: {e}")
        except Exception as e:
            logger.debug(f"Error extracting metadata from last chunk: {e}")

    # Return minimal structure with usage_metadata in outputs
    return {
        "content": full_text,
        "usage_metadata": usage_metadata,
    }
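# Illustrative reduction (hypothetical chunks): streamed chunks with text "Once",
# " upon", " a time" are collapsed into
#     {"content": "Once upon a time", "usage_metadata": {...}}
# where usage_metadata is taken from the final chunk, if it carries any.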


def _get_wrapper(
    original_generate: Callable,
    name: str,
    tracing_extra: Optional[TracingExtra] = None,
    is_streaming: bool = False,
) -> Callable:
    """Create a wrapper for Gemini's `generate_content` methods."""
    textra = tracing_extra or {}

    @functools.wraps(original_generate)
    def generate(*args, **kwargs):
        # Handle config object before tracing setup
        _convert_config_for_tracing(kwargs)
        decorator = run_helpers.traceable(
            name=name,
            run_type="llm",
            reduce_fn=_reduce_generate_content_chunks if is_streaming else None,
            process_inputs=_process_gemini_inputs,
            process_outputs=(
                _process_generate_content_response if not is_streaming else None
            ),
            _invocation_params_fn=_infer_invocation_params,
            **textra,
        )
        return decorator(original_generate)(*args, **kwargs)

    @functools.wraps(original_generate)
    async def agenerate(*args, **kwargs):
        # Handle config object before tracing setup
        _convert_config_for_tracing(kwargs)
        decorator = run_helpers.traceable(
            name=name,
            run_type="llm",
            reduce_fn=_reduce_generate_content_chunks if is_streaming else None,
            process_inputs=_process_gemini_inputs,
            process_outputs=(
                _process_generate_content_response if not is_streaming else None
            ),
            _invocation_params_fn=_infer_invocation_params,
            **textra,
        )
        return await decorator(original_generate)(*args, **kwargs)

    return agenerate if run_helpers.is_async(original_generate) else generate


class TracingExtra(TypedDict, total=False):
    metadata: Optional[Mapping[str, Any]]
    tags: Optional[list[str]]
    client: Optional[ls_client.Client]


@warn_beta
def wrap_gemini(
    client: C,
    *,
    tracing_extra: Optional[TracingExtra] = None,
    chat_name: str = "ChatGoogleGenerativeAI",
) -> C:
    """Patch the Google Gen AI client to make it traceable.

    !!! warning
        **BETA**: This wrapper is in beta.

    Supports:

    - `generate_content` and `generate_content_stream` methods
    - Sync and async clients
    - Streaming and non-streaming responses
    - Tool/function calling with proper UI rendering
    - Multimodal inputs (text + images)
    - Image generation with `inline_data` support
    - Token usage tracking including reasoning tokens

    Args:
        client: The Google Gen AI client to patch.
        tracing_extra: Extra tracing information.
        chat_name: The run name for the chat endpoint.

    Returns:
        The patched client.

    Example:
        ```python
        from google import genai
        from google.genai import types

        from langsmith import wrappers

        # Use Google Gen AI client same as you normally would.
        client = wrappers.wrap_gemini(genai.Client(api_key="your-api-key"))

        # Basic text generation:
        response = client.models.generate_content(
            model="gemini-2.5-flash",
            contents="Why is the sky blue?",
        )
        print(response.text)

        # Streaming:
        for chunk in client.models.generate_content_stream(
            model="gemini-2.5-flash",
            contents="Tell me a story",
        ):
            print(chunk.text, end="")

        # Tool/Function calling:
        schedule_meeting_function = {
            "name": "schedule_meeting",
            "description": "Schedules a meeting with specified attendees.",
            "parameters": {
                "type": "object",
                "properties": {
                    "attendees": {"type": "array", "items": {"type": "string"}},
                    "date": {"type": "string"},
                    "time": {"type": "string"},
                    "topic": {"type": "string"},
                },
                "required": ["attendees", "date", "time", "topic"],
            },
        }

        tools = types.Tool(function_declarations=[schedule_meeting_function])
        config = types.GenerateContentConfig(tools=[tools])

        response = client.models.generate_content(
            model="gemini-2.5-flash",
            contents="Schedule a meeting with Bob and Alice tomorrow at 2 PM.",
            config=config,
        )

        # Image generation:
        response = client.models.generate_content(
            model="gemini-2.5-flash-image",
            contents=["Create a picture of a futuristic city"],
        )

        # Save generated image
        from io import BytesIO

        from PIL import Image

        for part in response.candidates[0].content.parts:
            if part.inline_data is not None:
                image = Image.open(BytesIO(part.inline_data.data))
                image.save("generated_image.png")
        ```

    !!! version-added "Added in `langsmith` 0.4.33"

        Initial beta release of Google Gemini wrapper.
    """
    tracing_extra = tracing_extra or {}

    # Check if already wrapped to prevent double-wrapping
    if (
        hasattr(client, "models")
        and hasattr(client.models, "generate_content")
        and hasattr(client.models.generate_content, "__wrapped__")
    ):
        raise ValueError(
            "This Google Gen AI client has already been wrapped. "
            "Wrapping a client multiple times is not supported."
        )

    # Wrap synchronous methods
    if hasattr(client, "models") and hasattr(client.models, "generate_content"):
        client.models.generate_content = _get_wrapper(  # type: ignore[method-assign]
            client.models.generate_content,
            chat_name,
            tracing_extra=tracing_extra,
            is_streaming=False,
        )

    if hasattr(client, "models") and hasattr(client.models, "generate_content_stream"):
        client.models.generate_content_stream = _get_wrapper(  # type: ignore[method-assign]
            client.models.generate_content_stream,
            chat_name,
            tracing_extra=tracing_extra,
            is_streaming=True,
        )

    # Wrap async methods (aio namespace)
    if (
        hasattr(client, "aio")
        and hasattr(client.aio, "models")
        and hasattr(client.aio.models, "generate_content")
    ):
        client.aio.models.generate_content = _get_wrapper(  # type: ignore[method-assign]
            client.aio.models.generate_content,
            chat_name,
            tracing_extra=tracing_extra,
            is_streaming=False,
        )

    if (
        hasattr(client, "aio")
        and hasattr(client.aio, "models")
        and hasattr(client.aio.models, "generate_content_stream")
    ):
        client.aio.models.generate_content_stream = _get_wrapper(  # type: ignore[method-assign]
            client.aio.models.generate_content_stream,
            chat_name,
            tracing_extra=tracing_extra,
            is_streaming=True,
        )

    return client