testclient.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797
  1. import contextlib
  2. import inspect
  3. import io
  4. import json
  5. import math
  6. import queue
  7. import sys
  8. import typing
  9. import warnings
  10. from concurrent.futures import Future
  11. from types import GeneratorType
  12. from urllib.parse import unquote, urljoin
  13. import anyio
  14. import anyio.from_thread
  15. import httpx
  16. from anyio.streams.stapled import StapledObjectStream
  17. from starlette._utils import is_async_callable
  18. from starlette.types import ASGIApp, Message, Receive, Scope, Send
  19. from starlette.websockets import WebSocketDisconnect
  20. if sys.version_info >= (3, 8): # pragma: no cover
  21. from typing import TypedDict
  22. else: # pragma: no cover
  23. from typing_extensions import TypedDict
  24. _PortalFactoryType = typing.Callable[
  25. [], typing.ContextManager[anyio.abc.BlockingPortal]
  26. ]
  27. ASGIInstance = typing.Callable[[Receive, Send], typing.Awaitable[None]]
  28. ASGI2App = typing.Callable[[Scope], ASGIInstance]
  29. ASGI3App = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]
  30. _RequestData = typing.Mapping[str, typing.Union[str, typing.Iterable[str]]]
  31. def _is_asgi3(app: typing.Union[ASGI2App, ASGI3App]) -> bool:
  32. if inspect.isclass(app):
  33. return hasattr(app, "__await__")
  34. return is_async_callable(app)
  35. class _WrapASGI2:
  36. """
  37. Provide an ASGI3 interface onto an ASGI2 app.
  38. """
  39. def __init__(self, app: ASGI2App) -> None:
  40. self.app = app
  41. async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
  42. instance = self.app(scope)
  43. await instance(receive, send)
  44. class _AsyncBackend(TypedDict):
  45. backend: str
  46. backend_options: typing.Dict[str, typing.Any]
  47. class _Upgrade(Exception):
  48. def __init__(self, session: "WebSocketTestSession") -> None:
  49. self.session = session
  50. class WebSocketTestSession:
  51. def __init__(
  52. self,
  53. app: ASGI3App,
  54. scope: Scope,
  55. portal_factory: _PortalFactoryType,
  56. ) -> None:
  57. self.app = app
  58. self.scope = scope
  59. self.accepted_subprotocol = None
  60. self.portal_factory = portal_factory
  61. self._receive_queue: "queue.Queue[typing.Any]" = queue.Queue()
  62. self._send_queue: "queue.Queue[typing.Any]" = queue.Queue()
  63. self.extra_headers = None
  64. def __enter__(self) -> "WebSocketTestSession":
  65. self.exit_stack = contextlib.ExitStack()
  66. self.portal = self.exit_stack.enter_context(self.portal_factory())
  67. try:
  68. _: "Future[None]" = self.portal.start_task_soon(self._run)
  69. self.send({"type": "websocket.connect"})
  70. message = self.receive()
  71. self._raise_on_close(message)
  72. except Exception:
  73. self.exit_stack.close()
  74. raise
  75. self.accepted_subprotocol = message.get("subprotocol", None)
  76. self.extra_headers = message.get("headers", None)
  77. return self
  78. def __exit__(self, *args: typing.Any) -> None:
  79. try:
  80. self.close(1000)
  81. finally:
  82. self.exit_stack.close()
  83. while not self._send_queue.empty():
  84. message = self._send_queue.get()
  85. if isinstance(message, BaseException):
  86. raise message
  87. async def _run(self) -> None:
  88. """
  89. The sub-thread in which the websocket session runs.
  90. """
  91. scope = self.scope
  92. receive = self._asgi_receive
  93. send = self._asgi_send
  94. try:
  95. await self.app(scope, receive, send)
  96. except BaseException as exc:
  97. self._send_queue.put(exc)
  98. raise
  99. async def _asgi_receive(self) -> Message:
  100. while self._receive_queue.empty():
  101. await anyio.sleep(0)
  102. return self._receive_queue.get()
  103. async def _asgi_send(self, message: Message) -> None:
  104. self._send_queue.put(message)
  105. def _raise_on_close(self, message: Message) -> None:
  106. if message["type"] == "websocket.close":
  107. raise WebSocketDisconnect(
  108. message.get("code", 1000), message.get("reason", "")
  109. )
  110. def send(self, message: Message) -> None:
  111. self._receive_queue.put(message)
  112. def send_text(self, data: str) -> None:
  113. self.send({"type": "websocket.receive", "text": data})
  114. def send_bytes(self, data: bytes) -> None:
  115. self.send({"type": "websocket.receive", "bytes": data})
  116. def send_json(self, data: typing.Any, mode: str = "text") -> None:
  117. assert mode in ["text", "binary"]
  118. text = json.dumps(data, separators=(",", ":"))
  119. if mode == "text":
  120. self.send({"type": "websocket.receive", "text": text})
  121. else:
  122. self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")})
  123. def close(self, code: int = 1000) -> None:
  124. self.send({"type": "websocket.disconnect", "code": code})
  125. def receive(self) -> Message:
  126. message = self._send_queue.get()
  127. if isinstance(message, BaseException):
  128. raise message
  129. return message
  130. def receive_text(self) -> str:
  131. message = self.receive()
  132. self._raise_on_close(message)
  133. return message["text"]
  134. def receive_bytes(self) -> bytes:
  135. message = self.receive()
  136. self._raise_on_close(message)
  137. return message["bytes"]
  138. def receive_json(self, mode: str = "text") -> typing.Any:
  139. assert mode in ["text", "binary"]
  140. message = self.receive()
  141. self._raise_on_close(message)
  142. if mode == "text":
  143. text = message["text"]
  144. else:
  145. text = message["bytes"].decode("utf-8")
  146. return json.loads(text)
  147. class _TestClientTransport(httpx.BaseTransport):
  148. def __init__(
  149. self,
  150. app: ASGI3App,
  151. portal_factory: _PortalFactoryType,
  152. raise_server_exceptions: bool = True,
  153. root_path: str = "",
  154. *,
  155. app_state: typing.Dict[str, typing.Any],
  156. ) -> None:
  157. self.app = app
  158. self.raise_server_exceptions = raise_server_exceptions
  159. self.root_path = root_path
  160. self.portal_factory = portal_factory
  161. self.app_state = app_state
  162. def handle_request(self, request: httpx.Request) -> httpx.Response:
  163. scheme = request.url.scheme
  164. netloc = request.url.netloc.decode(encoding="ascii")
  165. path = request.url.path
  166. raw_path = request.url.raw_path
  167. query = request.url.query.decode(encoding="ascii")
  168. default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme]
  169. if ":" in netloc:
  170. host, port_string = netloc.split(":", 1)
  171. port = int(port_string)
  172. else:
  173. host = netloc
  174. port = default_port
  175. # Include the 'host' header.
  176. if "host" in request.headers:
  177. headers: typing.List[typing.Tuple[bytes, bytes]] = []
  178. elif port == default_port: # pragma: no cover
  179. headers = [(b"host", host.encode())]
  180. else: # pragma: no cover
  181. headers = [(b"host", (f"{host}:{port}").encode())]
  182. # Include other request headers.
  183. headers += [
  184. (key.lower().encode(), value.encode())
  185. for key, value in request.headers.items()
  186. ]
  187. scope: typing.Dict[str, typing.Any]
  188. if scheme in {"ws", "wss"}:
  189. subprotocol = request.headers.get("sec-websocket-protocol", None)
  190. if subprotocol is None:
  191. subprotocols: typing.Sequence[str] = []
  192. else:
  193. subprotocols = [value.strip() for value in subprotocol.split(",")]
  194. scope = {
  195. "type": "websocket",
  196. "path": unquote(path),
  197. "raw_path": raw_path,
  198. "root_path": self.root_path,
  199. "scheme": scheme,
  200. "query_string": query.encode(),
  201. "headers": headers,
  202. "client": ["testclient", 50000],
  203. "server": [host, port],
  204. "subprotocols": subprotocols,
  205. "state": self.app_state.copy(),
  206. }
  207. session = WebSocketTestSession(self.app, scope, self.portal_factory)
  208. raise _Upgrade(session)
  209. scope = {
  210. "type": "http",
  211. "http_version": "1.1",
  212. "method": request.method,
  213. "path": unquote(path),
  214. "raw_path": raw_path,
  215. "root_path": self.root_path,
  216. "scheme": scheme,
  217. "query_string": query.encode(),
  218. "headers": headers,
  219. "client": ["testclient", 50000],
  220. "server": [host, port],
  221. "extensions": {"http.response.debug": {}},
  222. "state": self.app_state.copy(),
  223. }
  224. request_complete = False
  225. response_started = False
  226. response_complete: anyio.Event
  227. raw_kwargs: typing.Dict[str, typing.Any] = {"stream": io.BytesIO()}
  228. template = None
  229. context = None
  230. async def receive() -> Message:
  231. nonlocal request_complete
  232. if request_complete:
  233. if not response_complete.is_set():
  234. await response_complete.wait()
  235. return {"type": "http.disconnect"}
  236. body = request.read()
  237. if isinstance(body, str):
  238. body_bytes: bytes = body.encode("utf-8") # pragma: no cover
  239. elif body is None:
  240. body_bytes = b"" # pragma: no cover
  241. elif isinstance(body, GeneratorType):
  242. try: # pragma: no cover
  243. chunk = body.send(None)
  244. if isinstance(chunk, str):
  245. chunk = chunk.encode("utf-8")
  246. return {"type": "http.request", "body": chunk, "more_body": True}
  247. except StopIteration: # pragma: no cover
  248. request_complete = True
  249. return {"type": "http.request", "body": b""}
  250. else:
  251. body_bytes = body
  252. request_complete = True
  253. return {"type": "http.request", "body": body_bytes}
  254. async def send(message: Message) -> None:
  255. nonlocal raw_kwargs, response_started, template, context
  256. if message["type"] == "http.response.start":
  257. assert (
  258. not response_started
  259. ), 'Received multiple "http.response.start" messages.'
  260. raw_kwargs["status_code"] = message["status"]
  261. raw_kwargs["headers"] = [
  262. (key.decode(), value.decode())
  263. for key, value in message.get("headers", [])
  264. ]
  265. response_started = True
  266. elif message["type"] == "http.response.body":
  267. assert (
  268. response_started
  269. ), 'Received "http.response.body" without "http.response.start".'
  270. assert (
  271. not response_complete.is_set()
  272. ), 'Received "http.response.body" after response completed.'
  273. body = message.get("body", b"")
  274. more_body = message.get("more_body", False)
  275. if request.method != "HEAD":
  276. raw_kwargs["stream"].write(body)
  277. if not more_body:
  278. raw_kwargs["stream"].seek(0)
  279. response_complete.set()
  280. elif message["type"] == "http.response.debug":
  281. template = message["info"]["template"]
  282. context = message["info"]["context"]
  283. try:
  284. with self.portal_factory() as portal:
  285. response_complete = portal.call(anyio.Event)
  286. portal.call(self.app, scope, receive, send)
  287. except BaseException as exc:
  288. if self.raise_server_exceptions:
  289. raise exc
  290. if self.raise_server_exceptions:
  291. assert response_started, "TestClient did not receive any response."
  292. elif not response_started:
  293. raw_kwargs = {
  294. "status_code": 500,
  295. "headers": [],
  296. "stream": io.BytesIO(),
  297. }
  298. raw_kwargs["stream"] = httpx.ByteStream(raw_kwargs["stream"].read())
  299. response = httpx.Response(**raw_kwargs, request=request)
  300. if template is not None:
  301. response.template = template # type: ignore[attr-defined]
  302. response.context = context # type: ignore[attr-defined]
  303. return response
  304. class TestClient(httpx.Client):
  305. __test__ = False
  306. task: "Future[None]"
  307. portal: typing.Optional[anyio.abc.BlockingPortal] = None
  308. def __init__(
  309. self,
  310. app: ASGIApp,
  311. base_url: str = "http://testserver",
  312. raise_server_exceptions: bool = True,
  313. root_path: str = "",
  314. backend: str = "asyncio",
  315. backend_options: typing.Optional[typing.Dict[str, typing.Any]] = None,
  316. cookies: httpx._client.CookieTypes = None,
  317. headers: typing.Dict[str, str] = None,
  318. ) -> None:
  319. self.async_backend = _AsyncBackend(
  320. backend=backend, backend_options=backend_options or {}
  321. )
  322. if _is_asgi3(app):
  323. app = typing.cast(ASGI3App, app)
  324. asgi_app = app
  325. else:
  326. app = typing.cast(ASGI2App, app) # type: ignore[assignment]
  327. asgi_app = _WrapASGI2(app) # type: ignore[arg-type]
  328. self.app = asgi_app
  329. self.app_state: typing.Dict[str, typing.Any] = {}
  330. transport = _TestClientTransport(
  331. self.app,
  332. portal_factory=self._portal_factory,
  333. raise_server_exceptions=raise_server_exceptions,
  334. root_path=root_path,
  335. app_state=self.app_state,
  336. )
  337. if headers is None:
  338. headers = {}
  339. headers.setdefault("user-agent", "testclient")
  340. super().__init__(
  341. app=self.app,
  342. base_url=base_url,
  343. headers=headers,
  344. transport=transport,
  345. follow_redirects=True,
  346. cookies=cookies,
  347. )
  348. @contextlib.contextmanager
  349. def _portal_factory(self) -> typing.Generator[anyio.abc.BlockingPortal, None, None]:
  350. if self.portal is not None:
  351. yield self.portal
  352. else:
  353. with anyio.from_thread.start_blocking_portal(
  354. **self.async_backend
  355. ) as portal:
  356. yield portal
  357. def _choose_redirect_arg(
  358. self,
  359. follow_redirects: typing.Optional[bool],
  360. allow_redirects: typing.Optional[bool],
  361. ) -> typing.Union[bool, httpx._client.UseClientDefault]:
  362. redirect: typing.Union[
  363. bool, httpx._client.UseClientDefault
  364. ] = httpx._client.USE_CLIENT_DEFAULT
  365. if allow_redirects is not None:
  366. message = (
  367. "The `allow_redirects` argument is deprecated. "
  368. "Use `follow_redirects` instead."
  369. )
  370. warnings.warn(message, DeprecationWarning)
  371. redirect = allow_redirects
  372. if follow_redirects is not None:
  373. redirect = follow_redirects
  374. elif allow_redirects is not None and follow_redirects is not None:
  375. raise RuntimeError( # pragma: no cover
  376. "Cannot use both `allow_redirects` and `follow_redirects`."
  377. )
  378. return redirect
  379. def request( # type: ignore[override]
  380. self,
  381. method: str,
  382. url: httpx._types.URLTypes,
  383. *,
  384. content: typing.Optional[httpx._types.RequestContent] = None,
  385. data: typing.Optional[_RequestData] = None,
  386. files: typing.Optional[httpx._types.RequestFiles] = None,
  387. json: typing.Any = None,
  388. params: typing.Optional[httpx._types.QueryParamTypes] = None,
  389. headers: typing.Optional[httpx._types.HeaderTypes] = None,
  390. cookies: typing.Optional[httpx._types.CookieTypes] = None,
  391. auth: typing.Union[
  392. httpx._types.AuthTypes, httpx._client.UseClientDefault
  393. ] = httpx._client.USE_CLIENT_DEFAULT,
  394. follow_redirects: typing.Optional[bool] = None,
  395. allow_redirects: typing.Optional[bool] = None,
  396. timeout: typing.Union[
  397. httpx._client.TimeoutTypes, httpx._client.UseClientDefault
  398. ] = httpx._client.USE_CLIENT_DEFAULT,
  399. extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
  400. ) -> httpx.Response:
  401. url = self.base_url.join(url)
  402. redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
  403. return super().request(
  404. method,
  405. url,
  406. content=content,
  407. data=data, # type: ignore[arg-type]
  408. files=files,
  409. json=json,
  410. params=params,
  411. headers=headers,
  412. cookies=cookies,
  413. auth=auth,
  414. follow_redirects=redirect,
  415. timeout=timeout,
  416. extensions=extensions,
  417. )
  418. def get( # type: ignore[override]
  419. self,
  420. url: httpx._types.URLTypes,
  421. *,
  422. params: typing.Optional[httpx._types.QueryParamTypes] = None,
  423. headers: typing.Optional[httpx._types.HeaderTypes] = None,
  424. cookies: typing.Optional[httpx._types.CookieTypes] = None,
  425. auth: typing.Union[
  426. httpx._types.AuthTypes, httpx._client.UseClientDefault
  427. ] = httpx._client.USE_CLIENT_DEFAULT,
  428. follow_redirects: typing.Optional[bool] = None,
  429. allow_redirects: typing.Optional[bool] = None,
  430. timeout: typing.Union[
  431. httpx._client.TimeoutTypes, httpx._client.UseClientDefault
  432. ] = httpx._client.USE_CLIENT_DEFAULT,
  433. extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
  434. ) -> httpx.Response:
  435. redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
  436. return super().get(
  437. url,
  438. params=params,
  439. headers=headers,
  440. cookies=cookies,
  441. auth=auth,
  442. follow_redirects=redirect,
  443. timeout=timeout,
  444. extensions=extensions,
  445. )
  446. def options( # type: ignore[override]
  447. self,
  448. url: httpx._types.URLTypes,
  449. *,
  450. params: typing.Optional[httpx._types.QueryParamTypes] = None,
  451. headers: typing.Optional[httpx._types.HeaderTypes] = None,
  452. cookies: typing.Optional[httpx._types.CookieTypes] = None,
  453. auth: typing.Union[
  454. httpx._types.AuthTypes, httpx._client.UseClientDefault
  455. ] = httpx._client.USE_CLIENT_DEFAULT,
  456. follow_redirects: typing.Optional[bool] = None,
  457. allow_redirects: typing.Optional[bool] = None,
  458. timeout: typing.Union[
  459. httpx._client.TimeoutTypes, httpx._client.UseClientDefault
  460. ] = httpx._client.USE_CLIENT_DEFAULT,
  461. extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
  462. ) -> httpx.Response:
  463. redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
  464. return super().options(
  465. url,
  466. params=params,
  467. headers=headers,
  468. cookies=cookies,
  469. auth=auth,
  470. follow_redirects=redirect,
  471. timeout=timeout,
  472. extensions=extensions,
  473. )
  474. def head( # type: ignore[override]
  475. self,
  476. url: httpx._types.URLTypes,
  477. *,
  478. params: typing.Optional[httpx._types.QueryParamTypes] = None,
  479. headers: typing.Optional[httpx._types.HeaderTypes] = None,
  480. cookies: typing.Optional[httpx._types.CookieTypes] = None,
  481. auth: typing.Union[
  482. httpx._types.AuthTypes, httpx._client.UseClientDefault
  483. ] = httpx._client.USE_CLIENT_DEFAULT,
  484. follow_redirects: typing.Optional[bool] = None,
  485. allow_redirects: typing.Optional[bool] = None,
  486. timeout: typing.Union[
  487. httpx._client.TimeoutTypes, httpx._client.UseClientDefault
  488. ] = httpx._client.USE_CLIENT_DEFAULT,
  489. extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
  490. ) -> httpx.Response:
  491. redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
  492. return super().head(
  493. url,
  494. params=params,
  495. headers=headers,
  496. cookies=cookies,
  497. auth=auth,
  498. follow_redirects=redirect,
  499. timeout=timeout,
  500. extensions=extensions,
  501. )
  502. def post( # type: ignore[override]
  503. self,
  504. url: httpx._types.URLTypes,
  505. *,
  506. content: typing.Optional[httpx._types.RequestContent] = None,
  507. data: typing.Optional[_RequestData] = None,
  508. files: typing.Optional[httpx._types.RequestFiles] = None,
  509. json: typing.Any = None,
  510. params: typing.Optional[httpx._types.QueryParamTypes] = None,
  511. headers: typing.Optional[httpx._types.HeaderTypes] = None,
  512. cookies: typing.Optional[httpx._types.CookieTypes] = None,
  513. auth: typing.Union[
  514. httpx._types.AuthTypes, httpx._client.UseClientDefault
  515. ] = httpx._client.USE_CLIENT_DEFAULT,
  516. follow_redirects: typing.Optional[bool] = None,
  517. allow_redirects: typing.Optional[bool] = None,
  518. timeout: typing.Union[
  519. httpx._client.TimeoutTypes, httpx._client.UseClientDefault
  520. ] = httpx._client.USE_CLIENT_DEFAULT,
  521. extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
  522. ) -> httpx.Response:
  523. redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
  524. return super().post(
  525. url,
  526. content=content,
  527. data=data, # type: ignore[arg-type]
  528. files=files,
  529. json=json,
  530. params=params,
  531. headers=headers,
  532. cookies=cookies,
  533. auth=auth,
  534. follow_redirects=redirect,
  535. timeout=timeout,
  536. extensions=extensions,
  537. )
  538. def put( # type: ignore[override]
  539. self,
  540. url: httpx._types.URLTypes,
  541. *,
  542. content: typing.Optional[httpx._types.RequestContent] = None,
  543. data: typing.Optional[_RequestData] = None,
  544. files: typing.Optional[httpx._types.RequestFiles] = None,
  545. json: typing.Any = None,
  546. params: typing.Optional[httpx._types.QueryParamTypes] = None,
  547. headers: typing.Optional[httpx._types.HeaderTypes] = None,
  548. cookies: typing.Optional[httpx._types.CookieTypes] = None,
  549. auth: typing.Union[
  550. httpx._types.AuthTypes, httpx._client.UseClientDefault
  551. ] = httpx._client.USE_CLIENT_DEFAULT,
  552. follow_redirects: typing.Optional[bool] = None,
  553. allow_redirects: typing.Optional[bool] = None,
  554. timeout: typing.Union[
  555. httpx._client.TimeoutTypes, httpx._client.UseClientDefault
  556. ] = httpx._client.USE_CLIENT_DEFAULT,
  557. extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
  558. ) -> httpx.Response:
  559. redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
  560. return super().put(
  561. url,
  562. content=content,
  563. data=data, # type: ignore[arg-type]
  564. files=files,
  565. json=json,
  566. params=params,
  567. headers=headers,
  568. cookies=cookies,
  569. auth=auth,
  570. follow_redirects=redirect,
  571. timeout=timeout,
  572. extensions=extensions,
  573. )
  574. def patch( # type: ignore[override]
  575. self,
  576. url: httpx._types.URLTypes,
  577. *,
  578. content: typing.Optional[httpx._types.RequestContent] = None,
  579. data: typing.Optional[_RequestData] = None,
  580. files: typing.Optional[httpx._types.RequestFiles] = None,
  581. json: typing.Any = None,
  582. params: typing.Optional[httpx._types.QueryParamTypes] = None,
  583. headers: typing.Optional[httpx._types.HeaderTypes] = None,
  584. cookies: typing.Optional[httpx._types.CookieTypes] = None,
  585. auth: typing.Union[
  586. httpx._types.AuthTypes, httpx._client.UseClientDefault
  587. ] = httpx._client.USE_CLIENT_DEFAULT,
  588. follow_redirects: typing.Optional[bool] = None,
  589. allow_redirects: typing.Optional[bool] = None,
  590. timeout: typing.Union[
  591. httpx._client.TimeoutTypes, httpx._client.UseClientDefault
  592. ] = httpx._client.USE_CLIENT_DEFAULT,
  593. extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
  594. ) -> httpx.Response:
  595. redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
  596. return super().patch(
  597. url,
  598. content=content,
  599. data=data, # type: ignore[arg-type]
  600. files=files,
  601. json=json,
  602. params=params,
  603. headers=headers,
  604. cookies=cookies,
  605. auth=auth,
  606. follow_redirects=redirect,
  607. timeout=timeout,
  608. extensions=extensions,
  609. )
  610. def delete( # type: ignore[override]
  611. self,
  612. url: httpx._types.URLTypes,
  613. *,
  614. params: typing.Optional[httpx._types.QueryParamTypes] = None,
  615. headers: typing.Optional[httpx._types.HeaderTypes] = None,
  616. cookies: typing.Optional[httpx._types.CookieTypes] = None,
  617. auth: typing.Union[
  618. httpx._types.AuthTypes, httpx._client.UseClientDefault
  619. ] = httpx._client.USE_CLIENT_DEFAULT,
  620. follow_redirects: typing.Optional[bool] = None,
  621. allow_redirects: typing.Optional[bool] = None,
  622. timeout: typing.Union[
  623. httpx._client.TimeoutTypes, httpx._client.UseClientDefault
  624. ] = httpx._client.USE_CLIENT_DEFAULT,
  625. extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
  626. ) -> httpx.Response:
  627. redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
  628. return super().delete(
  629. url,
  630. params=params,
  631. headers=headers,
  632. cookies=cookies,
  633. auth=auth,
  634. follow_redirects=redirect,
  635. timeout=timeout,
  636. extensions=extensions,
  637. )
  638. def websocket_connect(
  639. self, url: str, subprotocols: typing.Sequence[str] = None, **kwargs: typing.Any
  640. ) -> typing.Any:
  641. url = urljoin("ws://testserver", url)
  642. headers = kwargs.get("headers", {})
  643. headers.setdefault("connection", "upgrade")
  644. headers.setdefault("sec-websocket-key", "testserver==")
  645. headers.setdefault("sec-websocket-version", "13")
  646. if subprotocols is not None:
  647. headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols))
  648. kwargs["headers"] = headers
  649. try:
  650. super().request("GET", url, **kwargs)
  651. except _Upgrade as exc:
  652. session = exc.session
  653. else:
  654. raise RuntimeError("Expected WebSocket upgrade") # pragma: no cover
  655. return session
  656. def __enter__(self) -> "TestClient":
  657. with contextlib.ExitStack() as stack:
  658. self.portal = portal = stack.enter_context(
  659. anyio.from_thread.start_blocking_portal(**self.async_backend)
  660. )
  661. @stack.callback
  662. def reset_portal() -> None:
  663. self.portal = None
  664. self.stream_send = StapledObjectStream(
  665. *anyio.create_memory_object_stream(math.inf)
  666. )
  667. self.stream_receive = StapledObjectStream(
  668. *anyio.create_memory_object_stream(math.inf)
  669. )
  670. self.task = portal.start_task_soon(self.lifespan)
  671. portal.call(self.wait_startup)
  672. @stack.callback
  673. def wait_shutdown() -> None:
  674. portal.call(self.wait_shutdown)
  675. self.exit_stack = stack.pop_all()
  676. return self
  677. def __exit__(self, *args: typing.Any) -> None:
  678. self.exit_stack.close()
  679. async def lifespan(self) -> None:
  680. scope = {"type": "lifespan", "state": self.app_state}
  681. try:
  682. await self.app(scope, self.stream_receive.receive, self.stream_send.send)
  683. finally:
  684. await self.stream_send.send(None)
  685. async def wait_startup(self) -> None:
  686. await self.stream_receive.send({"type": "lifespan.startup"})
  687. async def receive() -> typing.Any:
  688. message = await self.stream_send.receive()
  689. if message is None:
  690. self.task.result()
  691. return message
  692. message = await receive()
  693. assert message["type"] in (
  694. "lifespan.startup.complete",
  695. "lifespan.startup.failed",
  696. )
  697. if message["type"] == "lifespan.startup.failed":
  698. await receive()
  699. async def wait_shutdown(self) -> None:
  700. async def receive() -> typing.Any:
  701. message = await self.stream_send.receive()
  702. if message is None:
  703. self.task.result()
  704. return message
  705. async with self.stream_send:
  706. await self.stream_receive.send({"type": "lifespan.shutdown"})
  707. message = await receive()
  708. assert message["type"] in (
  709. "lifespan.shutdown.complete",
  710. "lifespan.shutdown.failed",
  711. )
  712. if message["type"] == "lifespan.shutdown.failed":
  713. await receive()