  1. """LangSmith Pytest hooks."""
  2. import importlib.util
  3. import json
  4. import logging
  5. import os
  6. import time
  7. from collections import defaultdict
  8. from threading import Lock
  9. from typing import Any
  10. import pytest
  11. from langsmith import utils as ls_utils
  12. from langsmith.testing._internal import test as ls_test
  13. logger = logging.getLogger(__name__)
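
# Typical usage (illustrative sketch; the test below is invented): opt a test
# into LangSmith tracking with the `langsmith` marker registered by this
# plugin, then enable the rich live display with --langsmith-output:
#
#     # test_example.py
#     import pytest
#
#     @pytest.mark.langsmith
#     def test_addition():
#         assert 1 + 2 == 3
#
#     # $ pytest --langsmith-output test_example.py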


def pytest_addoption(parser):
    """Set a boolean flag for LangSmith output.

    Skip if --langsmith-output is already defined.
    """
    try:
        # Try to add the option; this raises ValueError if it already exists.
        group = parser.getgroup("langsmith", "LangSmith")
        group.addoption(
            "--langsmith-output",
            action="store_true",
            default=False,
            help="Use LangSmith output (requires 'rich').",
        )
    except ValueError:
        # Option already exists
        logger.warning(
            "LangSmith output flag cannot be added because it's already defined."
        )


def _handle_output_args(args):
    """Handle output arguments."""
    if any(opt in args for opt in ["--langsmith-output"]):
        # Only add --quiet if it's not already there
        if not any(a in args for a in ["-qq"]):
            args.insert(0, "-qq")
        # Disable built-in output capturing
        if not any(a in args for a in ["-s", "--capture=no"]):
            args.insert(0, "-s")
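
# Example: `pytest --langsmith-output tests/` is effectively rewritten to
# `pytest -s -qq --langsmith-output tests/`: -qq silences pytest's own
# progress output and -s disables capture, so the rich live display owns
# the terminal.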


if pytest.__version__.startswith("7."):

    def pytest_cmdline_preparse(config, args):
        """Call immediately after command line options are parsed (pytest v7)."""
        _handle_output_args(args)

else:

    def pytest_load_initial_conftests(args):
        """Handle args in pytest v8+."""
        _handle_output_args(args)


@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_call(item):
    """Apply LangSmith tracking to tests marked with @pytest.mark.langsmith."""
    marker = item.get_closest_marker("langsmith")
    if marker:
        # Get marker kwargs if any (e.g.,
        # @pytest.mark.langsmith(output_keys=["expected"]))
        kwargs = marker.kwargs if marker else {}
        # Wrap the test function with our test decorator
        original_func = item.obj
        item.obj = ls_test(**kwargs)(original_func)
        request_obj = getattr(item, "_request", None)
        if request_obj is not None and "request" not in item.funcargs:
            item.funcargs["request"] = request_obj
        if request_obj is not None and "request" not in item._fixtureinfo.argnames:
            # Create a new FuncFixtureInfo instance with updated argnames
            item._fixtureinfo = type(item._fixtureinfo)(
                argnames=item._fixtureinfo.argnames + ("request",),
                initialnames=item._fixtureinfo.initialnames,
                names_closure=item._fixtureinfo.names_closure,
                name2fixturedefs=item._fixtureinfo.name2fixturedefs,
            )
    # A hookwrapper must always yield exactly once, marked test or not.
    yield
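
# Illustrative sketch of marker kwargs (names here are examples, following the
# comment above): kwargs are forwarded to the LangSmith test decorator, so
# `output_keys` can mark pytest parameters as reference outputs:
#
#     @pytest.mark.langsmith(output_keys=["expected"])
#     @pytest.mark.parametrize("a, b, expected", [(1, 2, 3)])
#     def test_addition(a, b, expected):
#         assert a + b == expected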


@pytest.hookimpl
def pytest_report_teststatus(report, config):
    """Remove the short test-status character outputs ("./F")."""
    # The hook normally returns a 3-tuple: (short_letter, verbose_word, color).
    # By returning empty strings, the progress characters won't show.
    if config.getoption("--langsmith-output"):
        return "", "", ""


class LangSmithPlugin:
    """Plugin for rendering LangSmith results."""

    def __init__(self):
        """Initialize."""
        from rich.console import Console  # type: ignore[import-not-found]
        from rich.live import Live  # type: ignore[import-not-found]

        self.test_suites = defaultdict(list)
        self.test_suite_urls = {}
        self.process_status = {}  # Track process status
        self.status_lock = Lock()  # Thread-safe updates

        self.console = Console()
        self.live = Live(
            self.generate_tables(), console=self.console, refresh_per_second=10
        )
        self.live.start()
        self.live.console.print("Collecting tests...")

    def pytest_collection_finish(self, session):
        """Call after collection phase is completed and session.items is populated."""
        self.collected_nodeids = set()
        for item in session.items:
            self.collected_nodeids.add(item.nodeid)

    def add_process_to_test_suite(self, test_suite, process_id):
        """Group a test case with its test suite."""
        self.test_suites[test_suite].append(process_id)

    def update_process_status(self, process_id, status):
        """Update test results."""
        # First update
        if not self.process_status:
            self.live.console.print("Running tests...")
        with self.status_lock:
            current_status = self.process_status.get(process_id, {})
            self.process_status[process_id] = _merge_statuses(
                status,
                current_status,
                unpack=["feedback", "inputs", "reference_outputs", "outputs"],
            )
        self.live.update(self.generate_tables())
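
    # Illustrative status payload shape, inferred from how _generate_table
    # reads these entries (the richer fields are presumably populated by the
    # LangSmith test decorator, not in this file):
    #
    #     {
    #         "status": "passed",          # "running" | "passed" | "failed" | "skipped"
    #         "start_time": 1700000000.0,  # time.time() timestamps, in seconds
    #         "end_time": 1700000001.2,
    #         "inputs": {...},
    #         "outputs": {...},
    #         "reference_outputs": {...},
    #         "feedback": {"accuracy": 1},
    #     }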

    def pytest_runtest_logstart(self, nodeid):
        """Initialize live display when first test starts."""
        self.update_process_status(nodeid, {"status": "running"})

    def generate_tables(self):
        """Generate a collection of tables, one per suite.

        Returns a 'Group' object so it can be rendered simultaneously by Rich Live.
        """
        from rich.console import Group

        tables = []
        for suite_name in self.test_suites:
            table = self._generate_table(suite_name)
            tables.append(table)
        group = Group(*tables)
        return group

    def _generate_table(self, suite_name: str):
        """Generate results table."""
        from rich.table import Table  # type: ignore[import-not-found]

        process_ids = self.test_suites[suite_name]
        title = f"""Test Suite: [bold]{suite_name}[/bold]
LangSmith URL: [bright_cyan]{self.test_suite_urls[suite_name]}[/bright_cyan]"""  # noqa: E501
        table = Table(title=title, title_justify="left")
        table.add_column("Test")
        table.add_column("Inputs")
        table.add_column("Ref outputs")
        table.add_column("Outputs")
        table.add_column("Status")
        table.add_column("Feedback")
        table.add_column("Duration")

        # Test, inputs, ref outputs, outputs col width
        max_status = len("status")
        max_duration = len("duration")
        now = time.time()
        durations = []
        numeric_feedbacks = defaultdict(list)
        # Gather data only for this suite
        suite_statuses = {pid: self.process_status[pid] for pid in process_ids}
        for pid, status in suite_statuses.items():
            duration = status.get("end_time", now) - status.get("start_time", now)
            durations.append(duration)
            for k, v in status.get("feedback", {}).items():
                if isinstance(v, (float, int, bool)):
                    numeric_feedbacks[k].append(v)
            max_duration = max(len(f"{duration:.2f}s"), max_duration)
            max_status = max(len(status.get("status", "queued")), max_status)
        passed_count = sum(
            s.get("status") == "passed" for s in suite_statuses.values()
        )
        failed_count = sum(
            s.get("status") == "failed" for s in suite_statuses.values()
        )
        # Aggregated results go in a footer row, added after the per-test rows:
        if passed_count + failed_count:
            rate = passed_count / (passed_count + failed_count)
            color = "green" if rate == 1 else "red"
            aggregate_status = f"[{color}]{rate:.0%}[/{color}]"
        else:
            aggregate_status = "Passed: --"
        if durations:
            aggregate_duration = f"{sum(durations) / len(durations):.2f}s"
        else:
            aggregate_duration = "--s"
        if numeric_feedbacks:
            aggregate_feedback = "\n".join(
                f"{k}: {sum(v) / len(v)}" for k, v in numeric_feedbacks.items()
            )
        else:
            aggregate_feedback = "--"

        max_duration = max(max_duration, len(aggregate_duration))
        max_dynamic_col_width = (self.console.width - (max_status + max_duration)) // 5
        max_dynamic_col_width = max(max_dynamic_col_width, 8)

        for pid, status in suite_statuses.items():
            status_color = {
                "running": "yellow",
                "passed": "green",
                "failed": "red",
                "skipped": "cyan",
            }.get(status.get("status", "queued"), "white")
            duration = status.get("end_time", now) - status.get("start_time", now)
            feedback = "\n".join(
                f"{_abbreviate(k, max_len=max_dynamic_col_width)}: {int(v) if isinstance(v, bool) else v}"  # noqa: E501
                for k, v in status.get("feedback", {}).items()
            )
            inputs = _dumps_with_fallback(status.get("inputs", {}))
            reference_outputs = _dumps_with_fallback(
                status.get("reference_outputs", {})
            )
            outputs = _dumps_with_fallback(status.get("outputs", {}))
            table.add_row(
                _abbreviate_test_name(str(pid), max_len=max_dynamic_col_width),
                _abbreviate(inputs, max_len=max_dynamic_col_width),
                _abbreviate(reference_outputs, max_len=max_dynamic_col_width),
                _abbreviate(outputs, max_len=max_dynamic_col_width)[
                    -max_dynamic_col_width:
                ],
                f"[{status_color}]{status.get('status', 'queued')}[/{status_color}]",
                feedback,
                f"{duration:.2f}s",
            )
        # Blank spacer row between the per-test rows and the footer:
        table.add_row("", "", "", "", "", "", "")
        # Footer row with the aggregated results:
        table.add_row(
            "[bold]Averages[/bold]",
            "",
            "",
            "",
            aggregate_status,
            aggregate_feedback,
            aggregate_duration,
        )
        return table

    def pytest_configure(self, config):
        """Disable warning reporting and show no warnings in output."""
        # Disable general warning reporting
        config.option.showwarnings = False
        # Disable warning summary
        reporter = config.pluginmanager.get_plugin("warnings-plugin")
        if reporter:
            reporter.warning_summary = lambda *args, **kwargs: None

    def pytest_sessionfinish(self, session):
        """Stop Rich Live rendering at the end of the session."""
        self.live.stop()
        self.live.console.print("\nFinishing up...")


def pytest_configure(config):
    """Register the 'langsmith' marker."""
    config.addinivalue_line(
        "markers", "langsmith: mark test to be tracked in LangSmith"
    )
    if config.getoption("--langsmith-output"):
        if not importlib.util.find_spec("rich"):
            msg = (
                "Must have 'rich' installed to use --langsmith-output. "
                "Please install with: `pip install -U 'langsmith[pytest]'`"
            )
            raise ValueError(msg)
        if os.environ.get("PYTEST_XDIST_TESTRUNUID"):
            msg = (
                "--langsmith-output not supported with pytest-xdist. "
                "Please remove the '--langsmith-output' option or '-n' option."
            )
            raise ValueError(msg)
        if ls_utils.test_tracking_is_disabled():
            msg = (
                "--langsmith-output not supported when env var "
                "LANGSMITH_TEST_TRACKING='false'. Please remove the "
                "'--langsmith-output' option or enable test tracking."
            )
            raise ValueError(msg)
        config.pluginmanager.register(LangSmithPlugin(), "langsmith_output_plugin")
        # Suppress warnings summary
        config.option.showwarnings = False


def _abbreviate(x: str, max_len: int) -> str:
    if len(x) > max_len:
        return x[: max_len - 3] + "..."
    else:
        return x
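
# Example: strings longer than max_len are cut to max_len with a "..." suffix:
#     _abbreviate("hello world", max_len=8)  # -> "hello..."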


def _abbreviate_test_name(test_name: str, max_len: int) -> str:
    if len(test_name) > max_len:
        # Split on the first "::" only, so class-scoped node IDs
        # ("file.py::TestClass::test") keep their full test part.
        file, test = test_name.split("::", 1)
        if len(".py::" + test) > max_len:
            return "..." + test[-(max_len - 3) :]
        file_len = max_len - len("...::" + test)
        return "..." + file[-file_len:] + "::" + test
    else:
        return test_name
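
# Example: long node IDs keep the test name and truncate the file path:
#     _abbreviate_test_name("tests/test_foo.py::test_bar", max_len=15)
#     # -> "...py::test_bar"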


def _merge_statuses(update: dict, current: dict, *, unpack: list[str]) -> dict:
    for path in unpack:
        if path_update := update.pop(path, None):
            path_current = current.get(path, {})
            if isinstance(path_update, dict) and isinstance(path_current, dict):
                current[path] = {**path_current, **path_update}
            else:
                current[path] = path_update
    return {**current, **update}
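
# Example: keys listed in `unpack` are merged one level deep instead of being
# overwritten wholesale (note that both arguments are mutated):
#     _merge_statuses(
#         {"status": "passed", "feedback": {"accuracy": 1}},
#         {"status": "running", "feedback": {"latency": 0.2}},
#         unpack=["feedback"],
#     )
#     # -> {"status": "passed", "feedback": {"latency": 0.2, "accuracy": 1}}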


def _dumps_with_fallback(obj: Any) -> str:
    try:
        return json.dumps(obj)
    except Exception:
        return "unserializable"
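
# Example: values json.dumps can't handle (sets, custom objects, ...) fall
# back to a placeholder instead of raising:
#     _dumps_with_fallback({"a": 1})      # -> '{"a": 1}'
#     _dumps_with_fallback({"ids": {1}})  # -> "unserializable"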