- """LangSmith Pytest hooks."""
- import importlib.util
- import json
- import logging
- import os
- import time
- from collections import defaultdict
- from threading import Lock
- from typing import Any
- import pytest
- from langsmith import utils as ls_utils
- from langsmith.testing._internal import test as ls_test
- logger = logging.getLogger(__name__)
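# Illustrative usage (the test below is hypothetical, not part of this module):
# a test opts in to LangSmith tracking with the marker registered below, and
# the rich terminal output is enabled with the CLI flag added below:
#
#     @pytest.mark.langsmith
#     def test_addition():
#         assert 1 + 1 == 2
#
# invoked as: pytest --langsmith-output test_file.py
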
def pytest_addoption(parser):
    """Set a boolean flag for LangSmith output.

    Skip if --langsmith-output is already defined.
    """
    try:
        # Try to add the option; this raises ValueError if it already exists.
        group = parser.getgroup("langsmith", "LangSmith")
        group.addoption(
            "--langsmith-output",
            action="store_true",
            default=False,
            help="Use LangSmith output (requires 'rich').",
        )
    except ValueError:
        # Option already exists
        logger.warning(
            "LangSmith output flag cannot be added because it's already defined."
        )

def _handle_output_args(args):
    """Add '-qq' and '-s' to the args when --langsmith-output is requested."""
    if any(opt in args for opt in ["--langsmith-output"]):
        # Only add --quiet if it's not already there
        if not any(a in args for a in ["-qq"]):
            args.insert(0, "-qq")
        # Disable built-in output capturing
        if not any(a in args for a in ["-s", "--capture=no"]):
            args.insert(0, "-s")

if pytest.__version__.startswith("7."):

    def pytest_cmdline_preparse(config, args):
        """Call immediately after command line options are parsed (pytest v7)."""
        _handle_output_args(args)

else:

    def pytest_load_initial_conftests(args):
        """Handle args in pytest v8+."""
        _handle_output_args(args)

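# In effect, `pytest --langsmith-output` is rewritten to run as
# `pytest -s -qq --langsmith-output`: output capturing is disabled so Rich
# owns the terminal, and pytest's own progress output is silenced.
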
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_call(item):
    """Apply LangSmith tracking to tests marked with @pytest.mark.langsmith."""
    marker = item.get_closest_marker("langsmith")
    if marker:
        # Get marker kwargs, if any (e.g.,
        # @pytest.mark.langsmith(output_keys=["expected"]))
        kwargs = marker.kwargs
        # Wrap the test function with our test decorator
        original_func = item.obj
        item.obj = ls_test(**kwargs)(original_func)
        request_obj = getattr(item, "_request", None)
        if request_obj is not None and "request" not in item.funcargs:
            item.funcargs["request"] = request_obj
        if request_obj is not None and "request" not in item._fixtureinfo.argnames:
            # Create a new FuncFixtureInfo instance with updated argnames
            item._fixtureinfo = type(item._fixtureinfo)(
                argnames=item._fixtureinfo.argnames + ("request",),
                initialnames=item._fixtureinfo.initialnames,
                names_closure=item._fixtureinfo.names_closure,
                name2fixturedefs=item._fixtureinfo.name2fixturedefs,
            )
    yield

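# Note on the fixture patching above: `item._fixtureinfo` and `FuncFixtureInfo`
# are private pytest internals, so this is version-sensitive. The built-in
# `request` fixture is injected into the item so it reaches the wrapped test's
# kwargs; presumably the ls_test wrapper reads test metadata from it when
# present.
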
@pytest.hookimpl
def pytest_report_teststatus(report, config):
    """Remove the short test-status character outputs ("./F")."""
    # The hook normally returns a 3-tuple: (short_letter, verbose_word, color).
    # Returning empty strings hides the progress characters.
    if config.getoption("--langsmith-output"):
        return "", "", ""

class LangSmithPlugin:
    """Plugin for rendering LangSmith results."""

    def __init__(self):
        """Initialize."""
        from rich.console import Console  # type: ignore[import-not-found]
        from rich.live import Live  # type: ignore[import-not-found]

        self.test_suites = defaultdict(list)
        self.test_suite_urls = {}
        self.process_status = {}  # Track process status
        self.status_lock = Lock()  # Thread-safe updates

        self.console = Console()
        self.live = Live(
            self.generate_tables(), console=self.console, refresh_per_second=10
        )
        self.live.start()
        self.live.console.print("Collecting tests...")

    def pytest_collection_finish(self, session):
        """Called after the collection phase completes and session.items is populated."""
        self.collected_nodeids = set()
        for item in session.items:
            self.collected_nodeids.add(item.nodeid)

    def add_process_to_test_suite(self, test_suite, process_id):
        """Group a test case with its test suite."""
        self.test_suites[test_suite].append(process_id)

    def update_process_status(self, process_id, status):
        """Update test results."""
        # First update
        if not self.process_status:
            self.live.console.print("Running tests...")
        with self.status_lock:
            current_status = self.process_status.get(process_id, {})
            self.process_status[process_id] = _merge_statuses(
                status,
                current_status,
                unpack=["feedback", "inputs", "reference_outputs", "outputs"],
            )
        self.live.update(self.generate_tables())

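    # A `status` update dict may carry any of these keys, all optional:
    # "status" ("running"/"passed"/"failed"/"skipped"), "start_time"/"end_time"
    # (timestamps in seconds, compared against time.time()), "feedback"
    # ({metric_name: score}), and "inputs"/"reference_outputs"/"outputs"
    # (dicts of test data). Keys listed in `unpack` above are merged per-key
    # by _merge_statuses rather than replaced wholesale.
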
    def pytest_runtest_logstart(self, nodeid):
        """Initialize live display when the first test starts."""
        self.update_process_status(nodeid, {"status": "running"})

    def generate_tables(self):
        """Generate a collection of tables, one per test suite.

        Returns a 'Group' object so the tables can be rendered together by Rich Live.
        """
        from rich.console import Group

        tables = []
        for suite_name in self.test_suites:
            table = self._generate_table(suite_name)
            tables.append(table)
        group = Group(*tables)
        return group

    def _generate_table(self, suite_name: str):
        """Generate the results table for one test suite."""
        from rich.table import Table  # type: ignore[import-not-found]

        process_ids = self.test_suites[suite_name]
        title = f"""Test Suite: [bold]{suite_name}[/bold]
LangSmith URL: [bright_cyan]{self.test_suite_urls[suite_name]}[/bright_cyan]"""  # noqa: E501
        table = Table(title=title, title_justify="left")
        table.add_column("Test")
        table.add_column("Inputs")
        table.add_column("Ref outputs")
        table.add_column("Outputs")
        table.add_column("Status")
        table.add_column("Feedback")
        table.add_column("Duration")

        # Track the status/duration column widths; whatever is left over is
        # split among the test, inputs, ref outputs, outputs, and feedback
        # columns.
        max_status = len("status")
        max_duration = len("duration")
        now = time.time()
        durations = []
        numeric_feedbacks = defaultdict(list)
        # Gather data only for this suite
        suite_statuses = {pid: self.process_status[pid] for pid in process_ids}
        for pid, status in suite_statuses.items():
            duration = status.get("end_time", now) - status.get("start_time", now)
            durations.append(duration)
            for k, v in status.get("feedback", {}).items():
                if isinstance(v, (float, int, bool)):
                    numeric_feedbacks[k].append(v)
            max_duration = max(len(f"{duration:.2f}s"), max_duration)
            max_status = max(len(status.get("status", "queued")), max_status)

        passed_count = sum(
            s.get("status") == "passed" for s in suite_statuses.values()
        )
        failed_count = sum(
            s.get("status") == "failed" for s in suite_statuses.values()
        )
        # Aggregate pass rate, shown in the footer row's status column below.
        if passed_count + failed_count:
            rate = passed_count / (passed_count + failed_count)
            color = "green" if rate == 1 else "red"
            aggregate_status = f"[{color}]{rate:.0%}[/{color}]"
        else:
            aggregate_status = "Passed: --"
        if durations:
            aggregate_duration = f"{sum(durations) / len(durations):.2f}s"
        else:
            aggregate_duration = "--s"
        if numeric_feedbacks:
            aggregate_feedback = "\n".join(
                f"{k}: {sum(v) / len(v)}" for k, v in numeric_feedbacks.items()
            )
        else:
            aggregate_feedback = "--"
        max_duration = max(max_duration, len(aggregate_duration))

        max_dynamic_col_width = (self.console.width - (max_status + max_duration)) // 5
        max_dynamic_col_width = max(max_dynamic_col_width, 8)
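        # Worked example with hypothetical numbers: on a 120-character console
        # with max_status=6 and max_duration=8, each of the five dynamic
        # columns (test, inputs, ref outputs, outputs, feedback) gets
        # (120 - 14) // 5 = 21 characters, floored at 8.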
        for pid, status in suite_statuses.items():
            status_color = {
                "running": "yellow",
                "passed": "green",
                "failed": "red",
                "skipped": "cyan",
            }.get(status.get("status", "queued"), "white")
            duration = status.get("end_time", now) - status.get("start_time", now)
            feedback = "\n".join(
                f"{_abbreviate(k, max_len=max_dynamic_col_width)}: {int(v) if isinstance(v, bool) else v}"  # noqa: E501
                for k, v in status.get("feedback", {}).items()
            )
            inputs = _dumps_with_fallback(status.get("inputs", {}))
            reference_outputs = _dumps_with_fallback(
                status.get("reference_outputs", {})
            )
            outputs = _dumps_with_fallback(status.get("outputs", {}))
            table.add_row(
                _abbreviate_test_name(str(pid), max_len=max_dynamic_col_width),
                _abbreviate(inputs, max_len=max_dynamic_col_width),
                _abbreviate(reference_outputs, max_len=max_dynamic_col_width),
                _abbreviate(outputs, max_len=max_dynamic_col_width),
                f"[{status_color}]{status.get('status', 'queued')}[/{status_color}]",
                feedback,
                f"{duration:.2f}s",
            )

        # Blank spacer row before the footer.
        table.add_row("", "", "", "", "", "", "")
        # Footer row with the aggregated results.
        table.add_row(
            "[bold]Averages[/bold]",
            "",
            "",
            "",
            aggregate_status,
            aggregate_feedback,
            aggregate_duration,
        )
        return table

    def pytest_configure(self, config):
        """Disable warning reporting and show no warnings in output."""
        # Disable general warning reporting
        config.option.showwarnings = False

        # Disable warning summary
        reporter = config.pluginmanager.get_plugin("warnings-plugin")
        if reporter:
            reporter.warning_summary = lambda *args, **kwargs: None

    def pytest_sessionfinish(self, session):
        """Stop Rich Live rendering at the end of the session."""
        self.live.stop()
        self.live.console.print("\nFinishing up...")

def pytest_configure(config):
    """Register the 'langsmith' marker."""
    config.addinivalue_line(
        "markers", "langsmith: mark test to be tracked in LangSmith"
    )
    if config.getoption("--langsmith-output"):
        if not importlib.util.find_spec("rich"):
            msg = (
                "Must have 'rich' installed to use --langsmith-output. "
                "Please install with: `pip install -U 'langsmith[pytest]'`"
            )
            raise ValueError(msg)
        if os.environ.get("PYTEST_XDIST_TESTRUNUID"):
            msg = (
                "--langsmith-output is not supported with pytest-xdist. "
                "Please remove the '--langsmith-output' option or the '-n' option."
            )
            raise ValueError(msg)
        if ls_utils.test_tracking_is_disabled():
            msg = (
                "--langsmith-output is not supported when the env var "
                "LANGSMITH_TEST_TRACKING='false' is set. Please remove the "
                "'--langsmith-output' option or enable test tracking."
            )
            raise ValueError(msg)
        config.pluginmanager.register(LangSmithPlugin(), "langsmith_output_plugin")
        # Suppress warnings summary
        config.option.showwarnings = False

def _abbreviate(x: str, max_len: int) -> str:
    """Truncate a string to max_len characters, ending with '...'."""
    if len(x) > max_len:
        return x[: max_len - 3] + "..."
    else:
        return x

def _abbreviate_test_name(test_name: str, max_len: int) -> str:
    """Truncate a 'file::test' node ID to max_len characters."""
    if len(test_name) > max_len:
        file, test = test_name.split("::", maxsplit=1)
        if len(".py::" + test) > max_len:
            return "..." + test[-(max_len - 3) :]
        file_len = max_len - len("...::" + test)
        return "..." + file[-file_len:] + "::" + test
    else:
        return test_name

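# For example: _abbreviate("hello world", max_len=8) returns "hello...", and
# _abbreviate_test_name("tests/test_module.py::test_long_name", max_len=16)
# returns "...est_long_name": the file part is dropped first, and the test
# name itself is then truncated from the left. (Names are hypothetical.)
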
def _merge_statuses(update: dict, current: dict, *, unpack: list[str]) -> dict:
    """Merge a status update into the current status.

    Keys named in `unpack` are merged dict-by-dict instead of overwritten.
    """
    for path in unpack:
        if path_update := update.pop(path, None):
            path_current = current.get(path, {})
            if isinstance(path_update, dict) and isinstance(path_current, dict):
                current[path] = {**path_current, **path_update}
            else:
                current[path] = path_update
    return {**current, **update}

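# For example (hypothetical metric names):
#   _merge_statuses(
#       {"status": "passed", "feedback": {"accuracy": 1}},
#       {"status": "running", "feedback": {"f1": 0.5}},
#       unpack=["feedback"],
#   )
# returns {"status": "passed", "feedback": {"f1": 0.5, "accuracy": 1}}.
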
def _dumps_with_fallback(obj: Any) -> str:
    """Serialize to JSON, with a placeholder for unserializable objects."""
    try:
        return json.dumps(obj)
    except Exception:
        return "unserializable"
