_internal.py

from __future__ import annotations

import atexit
import contextlib
import contextvars
import datetime
import functools
import hashlib
import importlib.util
import inspect
import logging
import os
import threading
import time
import uuid
import warnings
from collections.abc import Generator, Sequence
from concurrent.futures import Future
from pathlib import Path
from typing import (
    Any,
    Callable,
    Optional,
    TypeVar,
    Union,
    cast,
    overload,
)

from typing_extensions import TypedDict

from langsmith import client as ls_client
from langsmith import env as ls_env
from langsmith import run_helpers as rh
from langsmith import run_trees
from langsmith import run_trees as rt
from langsmith import schemas as ls_schemas
from langsmith import utils as ls_utils
from langsmith._internal import _orjson
from langsmith._internal._serde import dumps_json
from langsmith.client import ID_TYPE

try:
    import pytest  # type: ignore

    SkipException = pytest.skip.Exception
except ImportError:

    class SkipException(Exception):  # type: ignore[no-redef]
        pass


logger = logging.getLogger(__name__)

# UUID5 namespace used for generating consistent example IDs
UUID5_NAMESPACE = uuid.UUID("6ba7b810-9dad-11d1-80b4-00c04fd430c8")

T = TypeVar("T")
U = TypeVar("U")


def _object_hash(obj: Any) -> str:
    """Hash an object to generate a consistent hash string."""
    # Use the existing serialization infrastructure with consistent ordering
    serialized = _stringify(obj)
    return hashlib.sha256(serialized.encode()).hexdigest()


@overload
def test(
    func: Callable,
) -> Callable: ...


@overload
def test(
    *,
    id: Optional[uuid.UUID] = None,
    output_keys: Optional[Sequence[str]] = None,
    client: Optional[ls_client.Client] = None,
    test_suite_name: Optional[str] = None,
    metadata: Optional[dict] = None,
    repetitions: Optional[int] = None,
    split: Optional[Union[str, list[str]]] = None,
    cached_hosts: Optional[Sequence[str]] = None,
) -> Callable[[Callable], Callable]: ...


def test(*args: Any, **kwargs: Any) -> Callable:
    """Trace a pytest test case in LangSmith.

    This decorator is used to trace a pytest test to LangSmith. It ensures that
    the necessary example data is created and associated with the test function.
    The decorated function will be executed as a test case, and the results will
    be recorded and reported by LangSmith.

    Args:
        - id (Optional[uuid.UUID]): A unique identifier for the test case. If not
          provided, an ID will be generated based on the test function's module
          and name.
        - output_keys (Optional[Sequence[str]]): A list of keys to be considered
          as the output keys for the test case. These keys will be extracted from
          the test function's inputs and stored as the expected outputs.
        - client (Optional[ls_client.Client]): An instance of the LangSmith client
          to be used for communication with the LangSmith service. If not
          provided, a default client will be used.
        - test_suite_name (Optional[str]): The name of the test suite to which the
          test case belongs. If not provided, the test suite name will be
          determined based on the environment or the package name.
        - metadata (Optional[dict]): Metadata to attach to the test case's example.
        - repetitions (Optional[int]): Number of times to run the test case.
          Defaults to 1.
        - split (Optional[Union[str, list[str]]]): The dataset split(s) to assign
          the example to.
        - cached_hosts (Optional[Sequence[str]]): A list of hosts or URL prefixes
          to cache requests to during testing. If not provided, all requests will
          be cached (default behavior). This is useful for caching only specific
          API calls (e.g., ["api.openai.com"] or ["https://api.openai.com"]).

    Returns:
        Callable: The decorated test function.

    Environment:
        - `LANGSMITH_TEST_CACHE`: If set, API calls will be cached to disk to
          save time and cost during testing. It is recommended to commit the
          cache files to your repository for faster CI/CD runs. Requires the
          'langsmith[vcr]' package to be installed.
        - `LANGSMITH_TEST_TRACKING`: Set this variable to 'false' to disable
          LangSmith test tracking entirely; decorated tests will then run
          without creating examples or recording results.

    Example:
        For basic usage, simply decorate a test function with
        `@pytest.mark.langsmith`. Under the hood this will call the `test`
        method:

        ```python
        import pytest


        # Equivalently can decorate with `test` directly:
        # from langsmith import test
        # @test
        @pytest.mark.langsmith
        def test_addition():
            assert 3 + 4 == 7
        ```

        Any code that is traced (such as code traced using `@traceable` or the
        `wrap_*` functions) will be traced within the test case for improved
        visibility and debugging.

        ```python
        import pytest
        from langsmith import traceable


        @traceable
        def generate_numbers():
            return 3, 4


        @pytest.mark.langsmith
        def test_nested():
            # Traced code will be included in the test case
            a, b = generate_numbers()
            assert a + b == 7
        ```

        LLM calls are expensive! Cache requests by setting
        `LANGSMITH_TEST_CACHE=path/to/cache`. Check in these files to speed up
        CI/CD pipelines, so your results only change when your prompt or
        requested model changes.

        Note that this will require that you install langsmith with the `vcr`
        extra:

        `pip install -U "langsmith[vcr]"`

        Caching is faster if you install libyaml. See
        https://vcrpy.readthedocs.io/en/latest/installation.html#speed for more
        details.

        ```python
        # os.environ["LANGSMITH_TEST_CACHE"] = "tests/cassettes"
        import openai
        import pytest
        from langsmith import wrappers

        oai_client = wrappers.wrap_openai(openai.Client())


        @pytest.mark.langsmith
        def test_openai_says_hello():
            # Traced code will be included in the test case
            response = oai_client.chat.completions.create(
                model="gpt-3.5-turbo",
                messages=[
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": "Say hello!"},
                ],
            )
            assert "hello" in response.choices[0].message.content.lower()
        ```

        You can also specify which hosts to cache by using the `cached_hosts`
        parameter. This is useful when you only want to cache specific API calls:

        ```python
        @pytest.mark.langsmith(cached_hosts=["https://api.openai.com"])
        def test_openai_with_selective_caching():
            # Only OpenAI API calls will be cached; other API calls will not be
            # cached
            response = oai_client.chat.completions.create(
                model="gpt-3.5-turbo",
                messages=[
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": "Say hello!"},
                ],
            )
            assert "hello" in response.choices[0].message.content.lower()
        ```

        LLMs are stochastic. Naive assertions are flaky. You can use langsmith's
        `expect` to score and make approximate assertions on your results.

        ```python
        import pytest
        from langsmith import expect


        @pytest.mark.langsmith
        def test_output_semantically_close():
            response = oai_client.chat.completions.create(
                model="gpt-3.5-turbo",
                messages=[
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": "Say hello!"},
                ],
            )
            # The embedding_distance call logs the embedding distance to LangSmith
            expect.embedding_distance(
                prediction=response.choices[0].message.content,
                reference="Hello!",
                # The following optional assertion logs a pass/fail score to
                # LangSmith and raises an AssertionError if the assertion fails.
            ).to_be_less_than(1.0)
            # Compute the Damerau-Levenshtein distance
            expect.edit_distance(
                prediction=response.choices[0].message.content,
                reference="Hello!",
                # And then log a pass/fail score to LangSmith
            ).to_be_less_than(1.0)
        ```

        The `@test` decorator works natively with pytest fixtures. The values
        will populate the "inputs" of the corresponding example in LangSmith.

        ```python
        import pytest


        @pytest.fixture
        def some_input():
            return "Some input"


        @pytest.mark.langsmith
        def test_with_fixture(some_input: str):
            assert "input" in some_input
        ```

        You can still use `pytest.parametrize()` as usual to run multiple test
        cases using the same test function.

        ```python
        import pytest


        @pytest.mark.langsmith(output_keys=["expected"])
        @pytest.mark.parametrize(
            "a, b, expected",
            [
                (1, 2, 3),
                (3, 4, 7),
            ],
        )
        def test_addition_with_multiple_inputs(a: int, b: int, expected: int):
            assert a + b == expected
        ```

        By default, each test case will be assigned a consistent, unique
        identifier based on the function name and module. You can also provide a
        custom identifier using the `id` argument:

        ```python
        import pytest
        import uuid

        example_id = uuid.uuid4()


        @pytest.mark.langsmith(id=str(example_id))
        def test_multiplication():
            assert 3 * 4 == 12
        ```

        By default, all test inputs are saved as "inputs" to a dataset. You can
        specify the `output_keys` argument to persist those keys within the
        dataset's "outputs" fields.

        ```python
        import pytest


        @pytest.fixture
        def expected_output():
            return "input"


        @pytest.mark.langsmith(output_keys=["expected_output"])
        def test_with_expected_output(some_input: str, expected_output: str):
            assert expected_output in some_input
        ```

        To run these tests, use the pytest CLI, or directly run the test
        functions:

        ```python
        test_output_semantically_close()
        test_addition()
        test_nested()
        test_with_fixture("Some input")
        test_with_expected_output("Some input", "Some")
        test_multiplication()
        test_openai_says_hello()
        test_addition_with_multiple_inputs(1, 2, 3)
        ```
    """
    cached_hosts = kwargs.pop("cached_hosts", None)
    cache_dir = ls_utils.get_cache_dir(kwargs.pop("cache", None))
    # Validate cached_hosts usage
    if cached_hosts and not cache_dir:
        raise ValueError(
            "cached_hosts parameter requires caching to be enabled. "
            "Please set the LANGSMITH_TEST_CACHE environment variable "
            "to a cache directory path, "
            "or pass a cache parameter to the test decorator. "
            "Example: LANGSMITH_TEST_CACHE='tests/cassettes' "
            "or @pytest.mark.langsmith(cache='tests/cassettes', cached_hosts=[...])"
        )
    langtest_extra = _UTExtra(
        id=kwargs.pop("id", None),
        output_keys=kwargs.pop("output_keys", None),
        client=kwargs.pop("client", None),
        test_suite_name=kwargs.pop("test_suite_name", None),
        cache=cache_dir,
        metadata=kwargs.pop("metadata", None),
        repetitions=kwargs.pop("repetitions", None),
        split=kwargs.pop("split", None),
        cached_hosts=cached_hosts,
    )
    if kwargs:
        warnings.warn(f"Unexpected keyword arguments: {kwargs.keys()}")
    disable_tracking = ls_utils.test_tracking_is_disabled()
    if disable_tracking:
        logger.info(
            "LANGSMITH_TEST_TRACKING is set to 'false'."
            " Skipping LangSmith test tracking."
        )

    def decorator(func: Callable) -> Callable:
        # Handle repetitions
        repetitions = langtest_extra.get("repetitions", 1) or 1
        if inspect.iscoroutinefunction(func):

            @functools.wraps(func)
            async def async_wrapper(
                *test_args: Any, request: Any = None, **test_kwargs: Any
            ):
                if disable_tracking:
                    return await func(*test_args, **test_kwargs)
                # Run the test multiple times for repetitions
                for i in range(repetitions):
                    repetition_extra = langtest_extra.copy()
                    await _arun_test(
                        func,
                        *test_args,
                        pytest_request=request,
                        **test_kwargs,
                        langtest_extra=repetition_extra,
                    )

            return async_wrapper

        @functools.wraps(func)
        def wrapper(*test_args: Any, request: Any = None, **test_kwargs: Any):
            if disable_tracking:
                return func(*test_args, **test_kwargs)
            # Run the test multiple times for repetitions
            for i in range(repetitions):
                repetition_extra = langtest_extra.copy()
                _run_test(
                    func,
                    *test_args,
                    pytest_request=request,
                    **test_kwargs,
                    langtest_extra=repetition_extra,
                )

        return wrapper

    if args and callable(args[0]):
        return decorator(args[0])
    return decorator


## Private functions


def _get_experiment_name(test_suite_name: str) -> str:
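    """Build a unique experiment name for this test run.

    The name is `<prefix>:<8-char id>`, where the prefix comes from the
    LANGSMITH_EXPERIMENT env var (or the tracer project name) and the id is
    kept stable across pytest-xdist workers so all processes in one run share
    a single experiment.
    """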
    # If this is a pytest-xdist multi-process run, we need to create the same
    # experiment name across processes. We can do this by accessing the
    # PYTEST_XDIST_TESTRUNUID env var.
    if os.environ.get("PYTEST_XDIST_TESTRUNUID") and importlib.util.find_spec("xdist"):
        id_name = test_suite_name + os.environ["PYTEST_XDIST_TESTRUNUID"]
        id_ = str(uuid.uuid5(uuid.NAMESPACE_DNS, id_name).hex[:8])
    else:
        id_ = str(uuid.uuid4().hex[:8])
    if os.environ.get("LANGSMITH_EXPERIMENT"):
        prefix = os.environ["LANGSMITH_EXPERIMENT"]
    else:
        prefix = ls_utils.get_tracer_project(False) or "TestSuiteResult"
    name = f"{prefix}:{id_}"
    return name


def _get_test_suite_name(func: Callable) -> str:
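    """Resolve the dataset (test suite) name for a test function.

    Prefers the LANGSMITH_TEST_SUITE env var; otherwise derives the name from
    the git repo name and the function's module.
    """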
    test_suite_name = ls_utils.get_env_var("TEST_SUITE")
    if test_suite_name:
        return test_suite_name
    repo_name = ls_env.get_git_info()["repo_name"]
    try:
        mod = inspect.getmodule(func)
        if mod:
            return f"{repo_name}.{mod.__name__}"
    except BaseException:
        logger.debug("Could not determine test suite name from file path.")
    raise ValueError("Please set the LANGSMITH_TEST_SUITE environment variable.")


def _get_test_suite(
    client: ls_client.Client, test_suite_name: str
) -> ls_schemas.Dataset:
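    """Fetch the dataset backing this test suite, creating it if needed.

    Creation races between processes are resolved by catching the conflict
    error and reading the dataset that the other process created.
    """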
    if client.has_dataset(dataset_name=test_suite_name):
        return client.read_dataset(dataset_name=test_suite_name)
    else:
        repo = ls_env.get_git_info().get("remote_url") or ""
        description = "Test suite"
        if repo:
            description += f" for {repo}"
        try:
            return client.create_dataset(
                dataset_name=test_suite_name,
                description=description,
                metadata={"__ls_runner": "pytest"},
            )
        except ls_utils.LangSmithConflictError:
            return client.read_dataset(dataset_name=test_suite_name)


def _start_experiment(
    client: ls_client.Client,
    test_suite: ls_schemas.Dataset,
) -> ls_schemas.TracerSession:
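    """Create (or read, on conflict) the tracer project for this test run."""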
    experiment_name = _get_experiment_name(test_suite.name)
    try:
        return client.create_project(
            experiment_name,
            reference_dataset_id=test_suite.id,
            description="Test Suite Results.",
            metadata={
                "revision_id": ls_env.get_langchain_env_var_metadata().get(
                    "revision_id"
                ),
                "__ls_runner": "pytest",
            },
        )
    except ls_utils.LangSmithConflictError:
        return client.read_project(project_name=experiment_name)


def _get_example_id(
    dataset_id: str,
    inputs: dict,
    outputs: Optional[dict] = None,
) -> uuid.UUID:
    """Generate example ID based on inputs, outputs, and dataset ID."""
    identifier_obj = (dataset_id, _object_hash(inputs), _object_hash(outputs or {}))
    identifier = _stringify(identifier_obj)
    return uuid.uuid5(UUID5_NAMESPACE, identifier)


def _get_example_id_legacy(
    func: Callable, inputs: Optional[dict], suite_id: uuid.UUID
) -> tuple[uuid.UUID, str]:
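    """Generate the legacy example ID derived from the test's file path and name.

    Used for datasets created by SDK versions older than 0.4.33, where example
    IDs were based on the test location rather than its inputs and outputs.
    Returns the UUID and the identifier string (minus the suite ID prefix).
    """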
    try:
        file_path = str(Path(inspect.getfile(func)).relative_to(Path.cwd()))
    except ValueError:
        # Fall back to module name if file path is not available
        file_path = func.__module__
    identifier = f"{suite_id}{file_path}::{func.__name__}"
    # If parametrized test, need to add inputs to identifier:
    if hasattr(func, "pytestmark") and any(
        m.name == "parametrize" for m in func.pytestmark
    ):
        identifier += _stringify(inputs)
    return uuid.uuid5(uuid.NAMESPACE_DNS, identifier), identifier[len(str(suite_id)) :]


def _end_tests(test_suite: _LangSmithTestSuite):
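    """Finalize a test suite at interpreter exit.

    Flushes pending result/feedback futures, stamps the experiment with git
    and dataset-version metadata, and tags the dataset version with the
    current git commit and branch.
    """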
    git_info = ls_env.get_git_info() or {}
    test_suite.shutdown()
    dataset_version = test_suite.get_dataset_version()
    dataset_id = test_suite._dataset.id
    test_suite.client.update_project(
        test_suite.experiment_id,
        metadata={
            **git_info,
            "dataset_version": dataset_version,
            "revision_id": ls_env.get_langchain_env_var_metadata().get("revision_id"),
            "__ls_runner": "pytest",
        },
    )
    if dataset_version and git_info.get("commit") is not None:
        test_suite.client.update_dataset_tag(
            dataset_id=dataset_id,
            as_of=dataset_version,
            tag=f"git:commit:{git_info['commit']}",
        )
    if dataset_version and git_info.get("branch") is not None:
        test_suite.client.update_dataset_tag(
            dataset_id=dataset_id,
            as_of=dataset_version,
            tag=f"git:branch:{git_info['branch']}",
        )


VT = TypeVar("VT", bound=Optional[dict])


def _serde_example_values(values: VT) -> VT:
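    """Round-trip values through the LangSmith JSON serializer.

    Ensures example inputs/outputs are JSON-safe so that local values compare
    consistently with what the server returns.
    """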
    if values is None:
        return cast(VT, values)
    bts = ls_client._dumps_json(values)
    return _orjson.loads(bts)


class _LangSmithTestSuite:
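    """Per-process singleton tying a dataset (test suite) to an experiment.

    One instance exists per test suite name (guarded by a class-level lock).
    Feedback and example updates are submitted to a thread pool so test
    execution isn't blocked on network calls; `_end_tests` is registered via
    atexit to flush the pool and finalize metadata.
    """
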
    _instances: Optional[dict] = None
    _lock = threading.RLock()

    def __init__(
        self,
        client: Optional[ls_client.Client],
        experiment: ls_schemas.TracerSession,
        dataset: ls_schemas.Dataset,
    ):
        self.client = client or rt.get_cached_client()
        self._experiment = experiment
        self._dataset = dataset
        self._dataset_version: Optional[datetime.datetime] = dataset.modified_at
        self._executor = ls_utils.ContextThreadPoolExecutor()
        atexit.register(_end_tests, self)

    @property
    def id(self):
        return self._dataset.id

    @property
    def experiment_id(self):
        return self._experiment.id

    @property
    def experiment(self):
        return self._experiment

    @classmethod
    def from_test(
        cls,
        client: Optional[ls_client.Client],
        func: Callable,
        test_suite_name: Optional[str] = None,
    ) -> _LangSmithTestSuite:
        client = client or rt.get_cached_client()
        test_suite_name = test_suite_name or _get_test_suite_name(func)
        with cls._lock:
            if not cls._instances:
                cls._instances = {}
            if test_suite_name not in cls._instances:
                test_suite = _get_test_suite(client, test_suite_name)
                experiment = _start_experiment(client, test_suite)
                cls._instances[test_suite_name] = cls(client, experiment, test_suite)
            return cls._instances[test_suite_name]

    @property
    def name(self):
        return self._experiment.name

    def get_dataset_version(self):
        return self._dataset_version

    def submit_result(
        self,
        run_id: uuid.UUID,
        error: Optional[str] = None,
        skipped: bool = False,
        pytest_plugin: Any = None,
        pytest_nodeid: Any = None,
    ) -> None:
        if skipped:
            score = None
            status = "skipped"
        elif error:
            score = 0
            status = "failed"
        else:
            score = 1
            status = "passed"
        if pytest_plugin and pytest_nodeid:
            pytest_plugin.update_process_status(pytest_nodeid, {"status": status})
        self._executor.submit(self._submit_result, run_id, score)

    def _submit_result(self, run_id: uuid.UUID, score: Optional[int]) -> None:
        # trace_id will always be run_id here because the feedback is on the
        # root test run
        self.client.create_feedback(run_id, key="pass", score=score, trace_id=run_id)

    def sync_example(
        self,
        example_id: uuid.UUID,
        *,
        inputs: Optional[dict] = None,
        outputs: Optional[dict] = None,
        metadata: Optional[dict] = None,
        split: Optional[Union[str, list[str]]] = None,
        pytest_plugin=None,
        pytest_nodeid=None,
    ) -> None:
        inputs = inputs or {}
        if pytest_plugin and pytest_nodeid:
            update = {"inputs": inputs, "reference_outputs": outputs}
            update = {k: v for k, v in update.items() if v is not None}
            pytest_plugin.update_process_status(pytest_nodeid, update)
        metadata = metadata.copy() if metadata else metadata
        inputs = _serde_example_values(inputs)
        outputs = _serde_example_values(outputs)
        try:
            example = self.client.read_example(example_id=example_id)
        except ls_utils.LangSmithNotFoundError:
            example = self.client.create_example(
                example_id=example_id,
                inputs=inputs,
                outputs=outputs,
                dataset_id=self.id,
                metadata=metadata,
                split=split,
                created_at=self._experiment.start_time,
            )
        else:
            normalized_split = split
            if isinstance(normalized_split, str):
                normalized_split = [normalized_split]
            if normalized_split and metadata:
                metadata["dataset_split"] = normalized_split
            existing_dataset_split = (example.metadata or {}).pop(
                "dataset_split", None
            )
            if (
                (inputs != example.inputs)
                or (outputs is not None and outputs != example.outputs)
                or (metadata is not None and metadata != example.metadata)
                or str(example.dataset_id) != str(self.id)
                or (
                    normalized_split is not None
                    and existing_dataset_split != normalized_split
                )
            ):
                self.client.update_example(
                    example_id=example.id,
                    inputs=inputs,
                    outputs=outputs,
                    metadata=metadata,
                    split=split,
                    dataset_id=self.id,
                )
                example = self.client.read_example(example_id=example.id)
        if self._dataset_version is None:
            self._dataset_version = example.modified_at
        elif (
            example.modified_at
            and self._dataset_version
            and example.modified_at > self._dataset_version
        ):
            self._dataset_version = example.modified_at

    def _submit_feedback(
        self,
        run_id: ID_TYPE,
        feedback: Union[dict, list],
        pytest_plugin: Any = None,
        pytest_nodeid: Any = None,
        **kwargs: Any,
    ):
        feedback = feedback if isinstance(feedback, list) else [feedback]
        for fb in feedback:
            if pytest_plugin and pytest_nodeid:
                val = fb["score"] if "score" in fb else fb["value"]
                pytest_plugin.update_process_status(
                    pytest_nodeid, {"feedback": {fb["key"]: val}}
                )
            self._executor.submit(
                self._create_feedback, run_id=run_id, feedback=fb, **kwargs
            )

    def _create_feedback(self, run_id: ID_TYPE, feedback: dict, **kwargs: Any) -> None:
        # trace_id will always be run_id here because the feedback is on the
        # root test run
        self.client.create_feedback(run_id, **feedback, **kwargs, trace_id=run_id)

    def shutdown(self):
        self._executor.shutdown()

    def end_run(
        self,
        run_tree,
        example_id,
        outputs,
        reference_outputs,
        metadata,
        split,
        pytest_plugin=None,
        pytest_nodeid=None,
    ) -> Future:
        return self._executor.submit(
            self._end_run,
            run_tree=run_tree,
            example_id=example_id,
            outputs=outputs,
            reference_outputs=reference_outputs,
            metadata=metadata,
            split=split,
            pytest_plugin=pytest_plugin,
            pytest_nodeid=pytest_nodeid,
        )

    def _end_run(
        self,
        run_tree,
        example_id,
        outputs,
        reference_outputs,
        metadata,
        split,
        pytest_plugin,
        pytest_nodeid,
    ) -> None:
        # TODO: remove this hack so that run durations are correct
        # Ensure the example is fully updated
        self.sync_example(
            example_id,
            inputs=run_tree.inputs,
            outputs=reference_outputs,
            split=split,
            metadata=metadata,
        )
        run_tree.reference_example_id = example_id
        run_tree.end(outputs=outputs, metadata={"reference_example_id": example_id})
        run_tree.patch()


class _TestCase:
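    """State for a single LangSmith-tracked test invocation.

    Bridges the test's run tree, the dataset example it maps to, and the
    optional pytest output plugin, relaying inputs/outputs/feedback to both
    LangSmith and the plugin's live status display.
    """
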
    def __init__(
        self,
        test_suite: _LangSmithTestSuite,
        run_id: uuid.UUID,
        example_id: Optional[uuid.UUID] = None,
        metadata: Optional[dict] = None,
        split: Optional[Union[str, list[str]]] = None,
        pytest_plugin: Any = None,
        pytest_nodeid: Any = None,
        inputs: Optional[dict] = None,
        reference_outputs: Optional[dict] = None,
    ) -> None:
        self.test_suite = test_suite
        self.example_id = example_id
        self.run_id = run_id
        self.metadata = metadata
        self.split = split
        self.pytest_plugin = pytest_plugin
        self.pytest_nodeid = pytest_nodeid
        self.inputs = inputs
        self.reference_outputs = reference_outputs
        self._logged_reference_outputs: Optional[dict] = None
        self._logged_outputs: Optional[dict] = None
        if pytest_plugin and pytest_nodeid:
            pytest_plugin.add_process_to_test_suite(
                test_suite._dataset.name, pytest_nodeid
            )
        if inputs:
            self.log_inputs(inputs)
        if reference_outputs:
            self.log_reference_outputs(reference_outputs)

    def submit_feedback(self, *args, **kwargs: Any):
        self.test_suite._submit_feedback(
            *args,
            **{
                **kwargs,
                **dict(
                    pytest_plugin=self.pytest_plugin,
                    pytest_nodeid=self.pytest_nodeid,
                ),
            },
        )

    def log_inputs(self, inputs: dict) -> None:
        if self.pytest_plugin and self.pytest_nodeid:
            self.pytest_plugin.update_process_status(
                self.pytest_nodeid, {"inputs": inputs}
            )

    def log_outputs(self, outputs: dict) -> None:
        self._logged_outputs = outputs
        if self.pytest_plugin and self.pytest_nodeid:
            self.pytest_plugin.update_process_status(
                self.pytest_nodeid, {"outputs": outputs}
            )

    def log_reference_outputs(self, reference_outputs: dict) -> None:
        self._logged_reference_outputs = reference_outputs
        if self.pytest_plugin and self.pytest_nodeid:
            self.pytest_plugin.update_process_status(
                self.pytest_nodeid, {"reference_outputs": reference_outputs}
            )

    def submit_test_result(
        self,
        error: Optional[str] = None,
        skipped: bool = False,
    ) -> None:
        return self.test_suite.submit_result(
            self.run_id,
            error=error,
            skipped=skipped,
            pytest_plugin=self.pytest_plugin,
            pytest_nodeid=self.pytest_nodeid,
        )

    def start_time(self) -> None:
        if self.pytest_plugin and self.pytest_nodeid:
            self.pytest_plugin.update_process_status(
                self.pytest_nodeid, {"start_time": time.time()}
            )

    def end_time(self) -> None:
        if self.pytest_plugin and self.pytest_nodeid:
            self.pytest_plugin.update_process_status(
                self.pytest_nodeid, {"end_time": time.time()}
            )

    def end_run(self, run_tree, outputs: Any) -> None:
        if not (outputs is None or isinstance(outputs, dict)):
            outputs = {"output": outputs}
        example_id = self.example_id or _get_example_id(
            dataset_id=str(self.test_suite.id),
            inputs=self.inputs or {},
            outputs=outputs,
        )
        self.test_suite.end_run(
            run_tree,
            example_id,
            outputs,
            reference_outputs=self._logged_reference_outputs,
            metadata=self.metadata,
            split=self.split,
            pytest_plugin=self.pytest_plugin,
            pytest_nodeid=self.pytest_nodeid,
        )


_TEST_CASE = contextvars.ContextVar[Optional[_TestCase]]("_TEST_CASE", default=None)


class _UTExtra(TypedDict, total=False):
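    """Options collected from the `test` decorator for a single test case."""
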
    client: Optional[ls_client.Client]
    id: Optional[uuid.UUID]
    output_keys: Optional[Sequence[str]]
    test_suite_name: Optional[str]
    cache: Optional[str]
    metadata: Optional[dict]
    repetitions: Optional[int]
    split: Optional[Union[str, list[str]]]
    cached_hosts: Optional[Sequence[str]]


def _create_test_case(
    func: Callable,
    *args: Any,
    pytest_request: Any,
    langtest_extra: _UTExtra,
    **kwargs: Any,
) -> _TestCase:
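    """Build the _TestCase for one invocation of a decorated test.

    Extracts inputs from the function signature, splits out any `output_keys`
    as reference outputs, resolves the test suite and example ID (falling back
    to the legacy location-based ID for datasets created before SDK 0.4.33),
    and wires up the pytest output plugin if present.
    """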
    client = langtest_extra["client"] or rt.get_cached_client()
    output_keys = langtest_extra["output_keys"]
    metadata = langtest_extra["metadata"]
    split = langtest_extra["split"]
    signature = inspect.signature(func)
    inputs = rh._get_inputs_safe(signature, *args, **kwargs) or None
    outputs = None
    if output_keys:
        outputs = {}
        if not inputs:
            msg = (
                "'output_keys' should only be specified when the marked test "
                "function has input arguments."
            )
            raise ValueError(msg)
        for k in output_keys:
            outputs[k] = inputs.pop(k, None)
    test_suite = _LangSmithTestSuite.from_test(
        client, func, langtest_extra.get("test_suite_name")
    )
    example_id = langtest_extra["id"]
    dataset_sdk_version = (
        test_suite._dataset.metadata
        and test_suite._dataset.metadata.get("runtime")
        and test_suite._dataset.metadata.get("runtime", {}).get("sdk_version")
    )
    if not dataset_sdk_version or not ls_utils.is_version_greater_or_equal(
        dataset_sdk_version, "0.4.33"
    ):
        legacy_example_id, example_name = _get_example_id_legacy(
            func, inputs, test_suite.id
        )
        example_id = example_id or legacy_example_id
    pytest_plugin = (
        pytest_request.config.pluginmanager.get_plugin("langsmith_output_plugin")
        if pytest_request
        else None
    )
    pytest_nodeid = pytest_request.node.nodeid if pytest_request else None
    if pytest_plugin:
        pytest_plugin.test_suite_urls[test_suite._dataset.name] = (
            cast(str, test_suite._dataset.url)
            + "/compare?selectedSessions="
            + str(test_suite.experiment_id)
        )
    test_case = _TestCase(
        test_suite,
        run_id=uuid.uuid4(),
        example_id=example_id,
        metadata=metadata,
        split=split,
        inputs=inputs,
        reference_outputs=outputs,
        pytest_plugin=pytest_plugin,
        pytest_nodeid=pytest_nodeid,
    )
    return test_case


def _run_test(
    func: Callable,
    *test_args: Any,
    pytest_request: Any,
    langtest_extra: _UTExtra,
    **test_kwargs: Any,
) -> None:
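    """Run a sync test case with tracing, caching, and result reporting.

    Wraps the test body in a LangSmith trace, records pass/fail/skip feedback,
    and (when a cache directory is configured) replays HTTP requests through
    the optional VCR cache, excluding calls to the LangSmith API itself.
    """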
    test_case = _create_test_case(
        func,
        *test_args,
        **test_kwargs,
        pytest_request=pytest_request,
        langtest_extra=langtest_extra,
    )
    _TEST_CASE.set(test_case)

    def _test():
        test_case.start_time()
        with rh.trace(
            name=getattr(func, "__name__", "Test"),
            run_id=test_case.run_id,
            inputs=test_case.inputs,
            metadata={
                # Experiment run metadata is prefixed with "ls_example_" in
                # the ingest backend, but we must reproduce this behavior here
                # because the example may not have been created before the
                # trace starts.
                f"ls_example_{k}": v
                for k, v in (test_case.metadata or {}).items()
            },
            project_name=test_case.test_suite.name,
            exceptions_to_handle=(SkipException,),
            _end_on_exit=False,
        ) as run_tree:
            try:
                result = func(*test_args, **test_kwargs)
            except SkipException as e:
                test_case.submit_test_result(error=repr(e), skipped=True)
                test_case.end_run(run_tree, {"skipped_reason": repr(e)})
                raise e
            except BaseException as e:
                test_case.submit_test_result(error=repr(e))
                test_case.end_run(run_tree, None)
                raise e
            else:
                test_case.end_run(run_tree, result)
            finally:
                test_case.end_time()
        try:
            test_case.submit_test_result()
        except BaseException as e:
            logger.warning(
                f"Failed to create feedback for run_id {test_case.run_id}:\n{e}"
            )

    if langtest_extra["cache"]:
        cache_path = Path(langtest_extra["cache"]) / f"{test_case.test_suite.id}.yaml"
    else:
        cache_path = None
    current_context = rh.get_tracing_context()
    metadata = {
        **(current_context["metadata"] or {}),
        **{
            "experiment": test_case.test_suite.experiment.name,
        },
    }
    # Handle the cached_hosts parameter
    ignore_hosts = [test_case.test_suite.client.api_url]
    allow_hosts = langtest_extra.get("cached_hosts") or None
    with (
        rh.tracing_context(**{**current_context, "metadata": metadata}),
        ls_utils.with_optional_cache(
            cache_path, ignore_hosts=ignore_hosts, allow_hosts=allow_hosts
        ),
    ):
        _test()


async def _arun_test(
    func: Callable,
    *test_args: Any,
    pytest_request: Any,
    langtest_extra: _UTExtra,
    **test_kwargs: Any,
) -> None:
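    """Async counterpart of `_run_test` for coroutine test functions."""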
    test_case = _create_test_case(
        func,
        *test_args,
        **test_kwargs,
        pytest_request=pytest_request,
        langtest_extra=langtest_extra,
    )
    _TEST_CASE.set(test_case)

    async def _test():
        test_case.start_time()
        with rh.trace(
            name=getattr(func, "__name__", "Test"),
            run_id=test_case.run_id,
            reference_example_id=test_case.example_id,
            inputs=test_case.inputs,
            metadata={
                # Experiment run metadata is prefixed with "ls_example_" in
                # the ingest backend, but we must reproduce this behavior here
                # because the example may not have been created before the
                # trace starts.
                f"ls_example_{k}": v
                for k, v in (test_case.metadata or {}).items()
            },
            project_name=test_case.test_suite.name,
            exceptions_to_handle=(SkipException,),
            _end_on_exit=False,
        ) as run_tree:
            try:
                result = await func(*test_args, **test_kwargs)
            except SkipException as e:
                test_case.submit_test_result(error=repr(e), skipped=True)
                test_case.end_run(run_tree, {"skipped_reason": repr(e)})
                raise e
            except BaseException as e:
                test_case.submit_test_result(error=repr(e))
                test_case.end_run(run_tree, None)
                raise e
            else:
                test_case.end_run(run_tree, result)
            finally:
                test_case.end_time()
        try:
            test_case.submit_test_result()
        except BaseException as e:
            logger.warning(
                f"Failed to create feedback for run_id {test_case.run_id}:\n{e}"
            )

    if langtest_extra["cache"]:
        cache_path = Path(langtest_extra["cache"]) / f"{test_case.test_suite.id}.yaml"
    else:
        cache_path = None
    current_context = rh.get_tracing_context()
    metadata = {
        **(current_context["metadata"] or {}),
        **{
            "experiment": test_case.test_suite.experiment.name,
            "reference_example_id": str(test_case.example_id),
        },
    }
    # Handle the cached_hosts parameter
    ignore_hosts = [test_case.test_suite.client.api_url]
    cached_hosts = langtest_extra.get("cached_hosts")
    allow_hosts = cached_hosts if cached_hosts else None
    with (
        rh.tracing_context(**{**current_context, "metadata": metadata}),
        ls_utils.with_optional_cache(
            cache_path, ignore_hosts=ignore_hosts, allow_hosts=allow_hosts
        ),
    ):
        await _test()


# For backwards compatibility
unit = test


def log_inputs(inputs: dict, /) -> None:
    """Log run inputs from within a pytest test run.

    !!! warning
        This API is in beta and might change in future versions.

    Should only be used in pytest tests decorated with @pytest.mark.langsmith.

    Args:
        inputs: Inputs to log.

    Example:
        ```python
        from langsmith import testing as t


        @pytest.mark.langsmith
        def test_foo() -> None:
            x = 0
            y = 1
            t.log_inputs({"x": x, "y": y})
            assert foo(x, y) == 2
        ```
    """
    if ls_utils.test_tracking_is_disabled():
        logger.info("LANGSMITH_TEST_TRACKING is set to 'false'. Skipping log_inputs.")
        return
    run_tree = rh.get_current_run_tree()
    test_case = _TEST_CASE.get()
    if not run_tree or not test_case:
        msg = (
            "log_inputs should only be called within a pytest test decorated with "
            "@pytest.mark.langsmith, and with tracing enabled (by setting the "
            "LANGSMITH_TRACING environment variable to 'true')."
        )
        raise ValueError(msg)
    run_tree.add_inputs(inputs)
    test_case.log_inputs(inputs)


def log_outputs(outputs: dict, /) -> None:
    """Log run outputs from within a pytest test run.

    !!! warning
        This API is in beta and might change in future versions.

    Should only be used in pytest tests decorated with @pytest.mark.langsmith.

    Args:
        outputs: Outputs to log.

    Example:
        ```python
        from langsmith import testing as t


        @pytest.mark.langsmith
        def test_foo() -> None:
            x = 0
            y = 1
            result = foo(x, y)
            t.log_outputs({"foo": result})
            assert result == 2
        ```
    """
    if ls_utils.test_tracking_is_disabled():
        logger.info("LANGSMITH_TEST_TRACKING is set to 'false'. Skipping log_outputs.")
        return
    run_tree = rh.get_current_run_tree()
    test_case = _TEST_CASE.get()
    if not run_tree or not test_case:
        msg = (
            "log_outputs should only be called within a pytest test decorated with "
            "@pytest.mark.langsmith, and with tracing enabled (by setting the "
            "LANGSMITH_TRACING environment variable to 'true')."
        )
        raise ValueError(msg)
    outputs = _dumpd(outputs)
    run_tree.add_outputs(outputs)
    test_case.log_outputs(outputs)


def log_reference_outputs(reference_outputs: dict, /) -> None:
    """Log example reference outputs from within a pytest test run.

    !!! warning
        This API is in beta and might change in future versions.

    Should only be used in pytest tests decorated with @pytest.mark.langsmith.

    Args:
        reference_outputs: Reference outputs to log.

    Example:
        ```python
        from langsmith import testing


        @pytest.mark.langsmith
        def test_foo() -> None:
            x = 0
            y = 1
            expected = 2
            testing.log_reference_outputs({"foo": expected})
            assert foo(x, y) == expected
        ```
    """
    if ls_utils.test_tracking_is_disabled():
        logger.info(
            "LANGSMITH_TEST_TRACKING is set to 'false'. Skipping log_reference_outputs."
        )
        return
    test_case = _TEST_CASE.get()
    if not test_case:
        msg = (
            "log_reference_outputs should only be called within a pytest test "
            "decorated with @pytest.mark.langsmith."
        )
        raise ValueError(msg)
    test_case.log_reference_outputs(reference_outputs)


def log_feedback(
    feedback: Optional[Union[dict, list[dict]]] = None,
    /,
    *,
    key: str = "",
    score: Optional[Union[int, bool, float]] = None,
    value: Optional[Union[str, int, float, bool]] = None,
    **kwargs: Any,
) -> None:
    """Log run feedback from within a pytest test run.

    !!! warning
        This API is in beta and might change in future versions.

    Should only be used in pytest tests decorated with @pytest.mark.langsmith.

    Args:
        feedback: Complete feedback dict(s) to log. Mutually exclusive with the
            key, score, and value args.
        key: Feedback name.
        score: Numerical feedback value.
        value: Categorical feedback value.
        kwargs: Any other Client.create_feedback args.

    Example:
        ```python
        import pytest
        from langsmith import testing as t


        @pytest.mark.langsmith
        def test_foo() -> None:
            x = 0
            y = 1
            expected = 2
            result = foo(x, y)
            t.log_feedback(key="right_type", score=isinstance(result, int))
            assert result == expected
        ```
    """
    if ls_utils.test_tracking_is_disabled():
        logger.info("LANGSMITH_TEST_TRACKING is set to 'false'. Skipping log_feedback.")
        return
    if feedback and any((key, score, value)):
        msg = "Must specify one of 'feedback' and ('key', 'score', 'value'), not both."
        raise ValueError(msg)
    elif not (feedback or key):
        msg = "Must specify at least one of 'feedback' or ('key', 'score', 'value')."
        raise ValueError(msg)
    elif key:
        feedback = {"key": key}
        if score is not None:
            feedback["score"] = score
        if value is not None:
            feedback["value"] = value
    run_tree = rh.get_current_run_tree()
    test_case = _TEST_CASE.get()
    if not run_tree or not test_case:
        msg = (
            "log_feedback should only be called within a pytest test decorated with "
            "@pytest.mark.langsmith, and with tracing enabled (by setting the "
            "LANGSMITH_TRACING environment variable to 'true')."
        )
        raise ValueError(msg)
    if run_tree.session_name == "evaluators" and run_tree.metadata.get(
        "reference_run_id"
    ):
        run_id = run_tree.metadata["reference_run_id"]
        run_tree.add_outputs(
            feedback if isinstance(feedback, dict) else {"feedback": feedback}
        )
        kwargs["source_run_id"] = run_tree.id
    else:
        run_id = run_tree.trace_id
    test_case.submit_feedback(run_id, cast(Union[list, dict], feedback), **kwargs)


@contextlib.contextmanager
def trace_feedback(
    *, name: str = "Feedback"
) -> Generator[Optional[run_trees.RunTree], None, None]:
  1157. """Trace the computation of a pytest run feedback as its own run.
  1158. !!! warning
  1159. This API is in beta and might change in future versions.
  1160. Args:
  1161. name: Feedback run name. Defaults to "Feedback".
  1162. Example:
  1163. ```python
  1164. import openai
  1165. import pytest
  1166. from langsmith import testing as t
  1167. from langsmith import wrappers
  1168. oai_client = wrappers.wrap_openai(openai.Client())
  1169. @pytest.mark.langsmith
  1170. def test_openai_says_hello():
  1171. # Traced code will be included in the test case
  1172. text = "Say hello!"
  1173. response = oai_client.chat.completions.create(
  1174. model="gpt-4o-mini",
  1175. messages=[
  1176. {"role": "system", "content": "You are a helpful assistant."},
  1177. {"role": "user", "content": text},
  1178. ],
  1179. )
  1180. t.log_inputs({"text": text})
  1181. t.log_outputs({"response": response.choices[0].message.content})
  1182. t.log_reference_outputs({"response": "hello!"})
  1183. # Use this context manager to trace any steps used for generating evaluation
  1184. # feedback separately from the main application logic
  1185. with t.trace_feedback():
  1186. grade = oai_client.chat.completions.create(
  1187. model="gpt-4o-mini",
  1188. messages=[
  1189. {
  1190. "role": "system",
  1191. "content": "Return 1 if 'hello' is in the user message and 0 otherwise.",
  1192. },
  1193. {
  1194. "role": "user",
  1195. "content": response.choices[0].message.content,
  1196. },
  1197. ],
  1198. )
  1199. # Make sure to log relevant feedback within the context for the
  1200. # trace to be associated with this feedback.
  1201. t.log_feedback(
  1202. key="llm_judge", score=float(grade.choices[0].message.content)
  1203. )
  1204. assert "hello" in response.choices[0].message.content.lower()
  1205. ```
  1206. """ # noqa: E501
  1207. if ls_utils.test_tracking_is_disabled():
  1208. logger.info("LANGSMITH_TEST_TRACKING is set to 'false'. Skipping log_feedback.")
  1209. yield None
  1210. return
  1211. test_case = _TEST_CASE.get()
  1212. if not test_case:
  1213. msg = (
  1214. "trace_feedback should only be called within a pytest test decorated with "
  1215. "@pytest.mark.langsmith, and with tracing enabled (by setting the "
  1216. "LANGSMITH_TRACING environment variable to 'true')."
  1217. )
  1218. raise ValueError(msg)
  1219. metadata = {
  1220. "experiment": test_case.test_suite.experiment.name,
  1221. "reference_example_id": test_case.example_id,
  1222. "reference_run_id": test_case.run_id,
  1223. }
  1224. with rh.trace(
  1225. name=name,
  1226. inputs=test_case._logged_outputs,
  1227. parent="ignore",
  1228. project_name="evaluators",
  1229. metadata=metadata,
  1230. ) as run_tree:
  1231. yield run_tree


def _stringify(x: Any) -> str:
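    """Serialize to a JSON string, falling back to str() if serialization fails."""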
    try:
        return dumps_json(x).decode("utf-8", errors="surrogateescape")
    except Exception:
        return str(x)


def _dumpd(x: Any) -> Any:
    """Serialize LangChain Serializable objects."""
    dumpd = _get_langchain_dumpd()
    if not dumpd:
        return x
    try:
        serialized = dumpd(x)
        return serialized
    except Exception:
        return x


@functools.lru_cache
def _get_langchain_dumpd() -> Optional[Callable]:
    try:
        from langchain_core.load import dumpd

        return dumpd
    except ImportError:
        return None