# pytest_plugin.py (4.9 KB)
  1. from __future__ import annotations
  2. from contextlib import contextmanager
  3. from inspect import isasyncgenfunction, iscoroutinefunction
  4. from typing import Any, Dict, Generator, Tuple, cast
  5. import pytest
  6. import sniffio
  7. from ._core._eventloop import get_all_backends, get_asynclib
  8. from .abc import TestRunner
# The TestRunner opened by the outermost active get_runner() call, or None
# when no runner is open. Nested get_runner() calls yield this instance
# instead of constructing another TestRunner; it is reset to None on exit.
_current_runner: TestRunner | None = None
  10. def extract_backend_and_options(backend: object) -> tuple[str, dict[str, Any]]:
  11. if isinstance(backend, str):
  12. return backend, {}
  13. elif isinstance(backend, tuple) and len(backend) == 2:
  14. if isinstance(backend[0], str) and isinstance(backend[1], dict):
  15. return cast(Tuple[str, Dict[str, Any]], backend)
  16. raise TypeError("anyio_backend must be either a string or tuple of (string, dict)")
  17. @contextmanager
  18. def get_runner(
  19. backend_name: str, backend_options: dict[str, Any]
  20. ) -> Generator[TestRunner, object, None]:
  21. global _current_runner
  22. if _current_runner:
  23. yield _current_runner
  24. return
  25. asynclib = get_asynclib(backend_name)
  26. token = None
  27. if sniffio.current_async_library_cvar.get(None) is None:
  28. # Since we're in control of the event loop, we can cache the name of the async library
  29. token = sniffio.current_async_library_cvar.set(backend_name)
  30. try:
  31. backend_options = backend_options or {}
  32. with asynclib.TestRunner(**backend_options) as runner:
  33. _current_runner = runner
  34. yield runner
  35. finally:
  36. _current_runner = None
  37. if token:
  38. sniffio.current_async_library_cvar.reset(token)
  39. def pytest_configure(config: Any) -> None:
  40. config.addinivalue_line(
  41. "markers",
  42. "anyio: mark the (coroutine function) test to be run "
  43. "asynchronously via anyio.",
  44. )
  45. def pytest_fixture_setup(fixturedef: Any, request: Any) -> None:
  46. def wrapper(*args, anyio_backend, **kwargs): # type: ignore[no-untyped-def]
  47. backend_name, backend_options = extract_backend_and_options(anyio_backend)
  48. if has_backend_arg:
  49. kwargs["anyio_backend"] = anyio_backend
  50. with get_runner(backend_name, backend_options) as runner:
  51. if isasyncgenfunction(func):
  52. yield from runner.run_asyncgen_fixture(func, kwargs)
  53. else:
  54. yield runner.run_fixture(func, kwargs)
  55. # Only apply this to coroutine functions and async generator functions in requests that involve
  56. # the anyio_backend fixture
  57. func = fixturedef.func
  58. if isasyncgenfunction(func) or iscoroutinefunction(func):
  59. if "anyio_backend" in request.fixturenames:
  60. has_backend_arg = "anyio_backend" in fixturedef.argnames
  61. fixturedef.func = wrapper
  62. if not has_backend_arg:
  63. fixturedef.argnames += ("anyio_backend",)
  64. @pytest.hookimpl(tryfirst=True)
  65. def pytest_pycollect_makeitem(collector: Any, name: Any, obj: Any) -> None:
  66. if collector.istestfunction(obj, name):
  67. inner_func = obj.hypothesis.inner_test if hasattr(obj, "hypothesis") else obj
  68. if iscoroutinefunction(inner_func):
  69. marker = collector.get_closest_marker("anyio")
  70. own_markers = getattr(obj, "pytestmark", ())
  71. if marker or any(marker.name == "anyio" for marker in own_markers):
  72. pytest.mark.usefixtures("anyio_backend")(obj)
  73. @pytest.hookimpl(tryfirst=True)
  74. def pytest_pyfunc_call(pyfuncitem: Any) -> bool | None:
  75. def run_with_hypothesis(**kwargs: Any) -> None:
  76. with get_runner(backend_name, backend_options) as runner:
  77. runner.run_test(original_func, kwargs)
  78. backend = pyfuncitem.funcargs.get("anyio_backend")
  79. if backend:
  80. backend_name, backend_options = extract_backend_and_options(backend)
  81. if hasattr(pyfuncitem.obj, "hypothesis"):
  82. # Wrap the inner test function unless it's already wrapped
  83. original_func = pyfuncitem.obj.hypothesis.inner_test
  84. if original_func.__qualname__ != run_with_hypothesis.__qualname__:
  85. if iscoroutinefunction(original_func):
  86. pyfuncitem.obj.hypothesis.inner_test = run_with_hypothesis
  87. return None
  88. if iscoroutinefunction(pyfuncitem.obj):
  89. funcargs = pyfuncitem.funcargs
  90. testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
  91. with get_runner(backend_name, backend_options) as runner:
  92. runner.run_test(pyfuncitem.obj, testargs)
  93. return True
  94. return None
  95. @pytest.fixture(params=get_all_backends())
  96. def anyio_backend(request: Any) -> Any:
  97. return request.param
  98. @pytest.fixture
  99. def anyio_backend_name(anyio_backend: Any) -> str:
  100. if isinstance(anyio_backend, str):
  101. return anyio_backend
  102. else:
  103. return anyio_backend[0]
  104. @pytest.fixture
  105. def anyio_backend_options(anyio_backend: Any) -> dict[str, Any]:
  106. if isinstance(anyio_backend, str):
  107. return {}
  108. else:
  109. return anyio_backend[1]