_expect.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463
  1. """Make approximate assertions as "expectations" on test results.
  2. This module is designed to be used within test cases decorated with the
  3. `@pytest.mark.langsmith` decorator.
  4. It allows you to log scores about a test case and optionally make assertions that log as
  5. "expectation" feedback to LangSmith.
  6. Example:
  7. ```python
  8. import pytest
  9. from langsmith import expect
  10. @pytest.mark.langsmith
  11. def test_output_semantically_close():
  12. response = oai_client.chat.completions.create(
  13. model="gpt-3.5-turbo",
  14. messages=[
  15. {"role": "system", "content": "You are a helpful assistant."},
  16. {"role": "user", "content": "Say hello!"},
  17. ],
  18. )
  19. response_txt = response.choices[0].message.content
  20. # Intended usage
  21. expect.embedding_distance(
  22. prediction=response_txt,
  23. reference="Hello!",
  24. ).to_be_less_than(0.9)
  25. # Score the test case
  26. matcher = expect.edit_distance(
  27. prediction=response_txt,
  28. reference="Hello!",
  29. )
  30. # Apply an assertion and log 'expectation' feedback to LangSmith
  31. matcher.to_be_less_than(1)
  32. # You can also make assertions on values directly
  33. expect.value(response_txt).to_contain("Hello!")
  34. # Or using a custom check
  35. expect.value(response_txt).against(lambda x: "Hello" in x)
  36. # You can even use this for basic metric logging within tests
  37. expect.score(0.8)
  38. expect.score(0.7, key="similarity").to_be_greater_than(0.7)
  39. ```
  40. """ # noqa: E501
  41. from __future__ import annotations
  42. import atexit
  43. import inspect
  44. from typing import (
  45. TYPE_CHECKING,
  46. Any,
  47. Callable,
  48. Literal,
  49. Optional,
  50. Union,
  51. overload,
  52. )
  53. from langsmith import client as ls_client
  54. from langsmith import run_helpers as rh
  55. from langsmith import run_trees as rt
  56. from langsmith import utils as ls_utils
  57. if TYPE_CHECKING:
  58. from langsmith._internal._edit_distance import EditDistanceConfig
  59. from langsmith._internal._embedding_distance import EmbeddingConfig
  60. # Sentinel class used until PEP 0661 is accepted
  61. class _NULL_SENTRY:
  62. """A sentinel singleton class used to distinguish omitted keyword arguments
  63. from those passed in with the value None (which may have different behavior).
  64. """ # noqa: D205
  65. def __bool__(self) -> Literal[False]:
  66. return False
  67. def __repr__(self) -> str:
  68. return "NOT_GIVEN"
  69. NOT_GIVEN = _NULL_SENTRY()
  70. class _Matcher:
  71. """A class for making assertions on expectation values."""
  72. def __init__(
  73. self,
  74. client: Optional[ls_client.Client],
  75. key: str,
  76. value: Any,
  77. _executor: Optional[ls_utils.ContextThreadPoolExecutor] = None,
  78. run_id: Optional[str] = None,
  79. ):
  80. self._client = client
  81. self.key = key
  82. self.value = value
  83. self._executor = _executor or ls_utils.ContextThreadPoolExecutor(max_workers=3)
  84. rt = rh.get_current_run_tree()
  85. self._run_id = rt.trace_id if rt else run_id
  86. def _submit_feedback(self, score: int, message: Optional[str] = None) -> None:
  87. if not ls_utils.test_tracking_is_disabled():
  88. if not self._client:
  89. self._client = rt.get_cached_client()
  90. self._executor.submit(
  91. self._client.create_feedback,
  92. run_id=self._run_id,
  93. key="expectation",
  94. score=score,
  95. comment=message,
  96. )
  97. def _assert(self, condition: bool, message: str, method_name: str) -> None:
  98. try:
  99. assert condition, message
  100. self._submit_feedback(1, message=f"Success: {self.key}.{method_name}")
  101. except AssertionError as e:
  102. self._submit_feedback(0, repr(e))
  103. raise e from None
  104. def to_be_less_than(self, value: float) -> None:
  105. """Assert that the expectation value is less than the given value.
  106. Args:
  107. value: The value to compare against.
  108. Raises:
  109. AssertionError: If the expectation value is not less than the given value.
  110. """
  111. self._assert(
  112. self.value < value,
  113. f"Expected {self.key} to be less than {value}, but got {self.value}",
  114. "to_be_less_than",
  115. )
  116. def to_be_greater_than(self, value: float) -> None:
  117. """Assert that the expectation value is greater than the given value.
  118. Args:
  119. value: The value to compare against.
  120. Raises:
  121. AssertionError: If the expectation value is not
  122. greater than the given value.
  123. """
  124. self._assert(
  125. self.value > value,
  126. f"Expected {self.key} to be greater than {value}, but got {self.value}",
  127. "to_be_greater_than",
  128. )
  129. def to_be_between(self, min_value: float, max_value: float) -> None:
  130. """Assert that the expectation value is between the given min and max values.
  131. Args:
  132. min_value: The minimum value (exclusive).
  133. max_value: The maximum value (exclusive).
  134. Raises:
  135. AssertionError: If the expectation value is not between the min and max.
  136. """
  137. self._assert(
  138. min_value < self.value < max_value,
  139. f"Expected {self.key} to be between {min_value} and {max_value},"
  140. f" but got {self.value}",
  141. "to_be_between",
  142. )
  143. def to_be_approximately(self, value: float, precision: int = 2) -> None:
  144. """Assert that the expectation value is approximately equal to the given value.
  145. Args:
  146. value: The value to compare against.
  147. precision: The number of decimal places to round to for comparison.
  148. Raises:
  149. AssertionError: If the rounded expectation value
  150. does not equal the rounded given value.
  151. """
  152. self._assert(
  153. round(self.value, precision) == round(value, precision),
  154. f"Expected {self.key} to be approximately {value}, but got {self.value}",
  155. "to_be_approximately",
  156. )
  157. def to_equal(self, value: float) -> None:
  158. """Assert that the expectation value equals the given value.
  159. Args:
  160. value: The value to compare against.
  161. Raises:
  162. AssertionError: If the expectation value does
  163. not exactly equal the given value.
  164. """
  165. self._assert(
  166. self.value == value,
  167. f"Expected {self.key} to be equal to {value}, but got {self.value}",
  168. "to_equal",
  169. )
  170. def to_be_none(self) -> None:
  171. """Assert that the expectation value is `None`.
  172. Raises:
  173. AssertionError: If the expectation value is not `None`.
  174. """
  175. self._assert(
  176. self.value is None,
  177. f"Expected {self.key} to be None, but got {self.value}",
  178. "to_be_none",
  179. )
  180. def to_contain(self, value: Any) -> None:
  181. """Assert that the expectation value contains the given value.
  182. Args:
  183. value: The value to check for containment.
  184. Raises:
  185. AssertionError: If the expectation value does not contain the given value.
  186. """
  187. self._assert(
  188. value in self.value,
  189. f"Expected {self.key} to contain {value}, but it does not",
  190. "to_contain",
  191. )
  192. # Custom assertions
  193. def against(self, func: Callable, /) -> None:
  194. """Assert the expectation value against a custom function.
  195. Args:
  196. func: A custom function that takes the expectation value as input.
  197. Raises:
  198. AssertionError: If the custom function returns False.
  199. """
  200. func_signature = inspect.signature(func)
  201. self._assert(
  202. func(self.value),
  203. f"Assertion {func_signature} failed for {self.key}",
  204. "against",
  205. )
  206. class _Expect:
  207. """A class for setting expectations on test results."""
  208. def __init__(self, *, client: Optional[ls_client.Client] = None):
  209. self._client = client
  210. self.executor = ls_utils.ContextThreadPoolExecutor(max_workers=3)
  211. atexit.register(self.executor.shutdown, wait=True)
  212. def embedding_distance(
  213. self,
  214. prediction: str,
  215. reference: str,
  216. *,
  217. config: Optional[EmbeddingConfig] = None,
  218. ) -> _Matcher:
  219. """Compute the embedding distance between the prediction and reference.
  220. This logs the embedding distance to LangSmith and returns a `_Matcher` instance
  221. for making assertions on the distance value.
  222. By default, this uses the OpenAI API for computing embeddings.
  223. Args:
  224. prediction: The predicted string to compare.
  225. reference: The reference string to compare against.
  226. config: Optional configuration for the embedding distance evaluator.
  227. Supported options:
  228. - `encoder`: A custom encoder function to encode the list of input
  229. strings to embeddings.
  230. Defaults to the OpenAI API.
  231. - `metric`: The distance metric to use for comparison.
  232. Supported values: `'cosine'`, `'euclidean'`, `'manhattan'`,
  233. `'chebyshev'`, `'hamming'`.
  234. Returns:
  235. A `_Matcher` instance for the embedding distance value.
  236. Example:
  237. ```python
  238. expect.embedding_distance(
  239. prediction="hello",
  240. reference="hi",
  241. ).to_be_less_than(1.0)
  242. ```
  243. """ # noqa: E501
  244. from langsmith._internal._embedding_distance import EmbeddingDistance
  245. config = config or {}
  246. encoder_func = "custom" if config.get("encoder") else "openai"
  247. evaluator = EmbeddingDistance(config=config)
  248. score = evaluator.evaluate(prediction=prediction, reference=reference)
  249. src_info = {"encoder": encoder_func, "metric": evaluator.distance}
  250. self._submit_feedback(
  251. "embedding_distance",
  252. {
  253. "score": score,
  254. "source_info": src_info,
  255. "comment": f"Using {encoder_func}, Metric: {evaluator.distance}",
  256. },
  257. )
  258. return _Matcher(
  259. self._client, "embedding_distance", score, _executor=self.executor
  260. )
  261. def edit_distance(
  262. self,
  263. prediction: str,
  264. reference: str,
  265. *,
  266. config: Optional[EditDistanceConfig] = None,
  267. ) -> _Matcher:
  268. """Compute the string distance between the prediction and reference.
  269. This logs the string distance (Damerau-Levenshtein) to LangSmith and returns
  270. a `_Matcher` instance for making assertions on the distance value.
  271. This depends on the `rapidfuzz` package for string distance computation.
  272. Args:
  273. prediction: The predicted string to compare.
  274. reference: The reference string to compare against.
  275. config: Optional configuration for the string distance evaluator.
  276. Supported options:
  277. - `metric`: The distance metric to use for comparison.
  278. Supported values: `'damerau_levenshtein'`, `'levenshtein'`,
  279. `'jaro'`, `'jaro_winkler'`, `'hamming'`, `'indel'`.
  280. - `normalize_score`: Whether to normalize the score between `0` and `1`.
  281. Returns:
  282. A `_Matcher` instance for the string distance value.
  283. Examples:
  284. ```python
  285. expect.edit_distance("hello", "helo").to_be_less_than(1)
  286. ```
  287. """
  288. from langsmith._internal._edit_distance import EditDistance
  289. config = config or {}
  290. metric = config.get("metric") or "damerau_levenshtein"
  291. normalize = config.get("normalize_score", True)
  292. evaluator = EditDistance(config=config)
  293. score = evaluator.evaluate(prediction=prediction, reference=reference)
  294. src_info = {"metric": metric, "normalize": normalize}
  295. self._submit_feedback(
  296. "edit_distance",
  297. {
  298. "score": score,
  299. "source_info": src_info,
  300. "comment": f"Using {metric}, Normalize: {normalize}",
  301. },
  302. )
  303. return _Matcher(
  304. self._client,
  305. "edit_distance",
  306. score,
  307. _executor=self.executor,
  308. )
  309. def value(self, value: Any) -> _Matcher:
  310. """Create a `_Matcher` instance for making assertions on the given value.
  311. Args:
  312. value: The value to make assertions on.
  313. Returns:
  314. A `_Matcher` instance for the given value.
  315. Example:
  316. ```python
  317. expect.value(10).to_be_less_than(20)
  318. ```
  319. """
  320. return _Matcher(self._client, "value", value, _executor=self.executor)
  321. def score(
  322. self,
  323. score: Union[float, int, bool],
  324. *,
  325. key: str = "score",
  326. source_run_id: Optional[ls_client.ID_TYPE] = None,
  327. comment: Optional[str] = None,
  328. ) -> _Matcher:
  329. """Log a numeric score to LangSmith.
  330. Args:
  331. score: The score value to log.
  332. key: The key to use for logging the score. Defaults to `'score'`.
  333. Example:
  334. ```python
  335. expect.score(0.8) # doctest: +ELLIPSIS
  336. <langsmith._expect._Matcher object at ...>
  337. expect.score(0.8, key="similarity").to_be_greater_than(0.7)
  338. ```
  339. """
  340. self._submit_feedback(
  341. key,
  342. {
  343. "score": score,
  344. "source_info": {"method": "expect.score"},
  345. "source_run_id": source_run_id,
  346. "comment": comment,
  347. },
  348. )
  349. return _Matcher(self._client, key, score, _executor=self.executor)
  350. ## Private Methods
  351. @overload
  352. def __call__(self, value: Any, /) -> _Matcher: ...
  353. @overload
  354. def __call__(self, /, *, client: ls_client.Client) -> _Expect: ...
  355. def __call__(
  356. self,
  357. value: Optional[Any] = NOT_GIVEN,
  358. /,
  359. client: Optional[ls_client.Client] = None,
  360. ) -> Union[_Expect, _Matcher]:
  361. expected = _Expect(client=client)
  362. if value is not NOT_GIVEN:
  363. return expected.value(value)
  364. return expected
  365. def _submit_feedback(self, key: str, results: dict):
  366. current_run = rh.get_current_run_tree()
  367. run_id = current_run.trace_id if current_run else None
  368. if not ls_utils.test_tracking_is_disabled():
  369. if not self._client:
  370. self._client = rt.get_cached_client()
  371. self.executor.submit(
  372. self._client.create_feedback, run_id=run_id, key=key, **results
  373. )
  374. expect = _Expect()
  375. __all__ = ["expect"]