base.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. import typing
  2. import anyio
  3. from starlette.background import BackgroundTask
  4. from starlette.requests import Request
  5. from starlette.responses import ContentStream, Response, StreamingResponse
  6. from starlette.types import ASGIApp, Message, Receive, Scope, Send
  7. RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
  8. DispatchFunction = typing.Callable[
  9. [Request, RequestResponseEndpoint], typing.Awaitable[Response]
  10. ]
  11. T = typing.TypeVar("T")
  12. class BaseHTTPMiddleware:
  13. def __init__(
  14. self, app: ASGIApp, dispatch: typing.Optional[DispatchFunction] = None
  15. ) -> None:
  16. self.app = app
  17. self.dispatch_func = self.dispatch if dispatch is None else dispatch
  18. async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
  19. if scope["type"] != "http":
  20. await self.app(scope, receive, send)
  21. return
  22. response_sent = anyio.Event()
  23. async def call_next(request: Request) -> Response:
  24. app_exc: typing.Optional[Exception] = None
  25. send_stream, recv_stream = anyio.create_memory_object_stream()
  26. async def receive_or_disconnect() -> Message:
  27. if response_sent.is_set():
  28. return {"type": "http.disconnect"}
  29. async with anyio.create_task_group() as task_group:
  30. async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T:
  31. result = await func()
  32. task_group.cancel_scope.cancel()
  33. return result
  34. task_group.start_soon(wrap, response_sent.wait)
  35. message = await wrap(request.receive)
  36. if response_sent.is_set():
  37. return {"type": "http.disconnect"}
  38. return message
  39. async def close_recv_stream_on_response_sent() -> None:
  40. await response_sent.wait()
  41. recv_stream.close()
  42. async def send_no_error(message: Message) -> None:
  43. try:
  44. await send_stream.send(message)
  45. except anyio.BrokenResourceError:
  46. # recv_stream has been closed, i.e. response_sent has been set.
  47. return
  48. async def coro() -> None:
  49. nonlocal app_exc
  50. async with send_stream:
  51. try:
  52. await self.app(scope, receive_or_disconnect, send_no_error)
  53. except Exception as exc:
  54. app_exc = exc
  55. task_group.start_soon(close_recv_stream_on_response_sent)
  56. task_group.start_soon(coro)
  57. try:
  58. message = await recv_stream.receive()
  59. info = message.get("info", None)
  60. if message["type"] == "http.response.debug" and info is not None:
  61. message = await recv_stream.receive()
  62. except anyio.EndOfStream:
  63. if app_exc is not None:
  64. raise app_exc
  65. raise RuntimeError("No response returned.")
  66. assert message["type"] == "http.response.start"
  67. async def body_stream() -> typing.AsyncGenerator[bytes, None]:
  68. async with recv_stream:
  69. async for message in recv_stream:
  70. assert message["type"] == "http.response.body"
  71. body = message.get("body", b"")
  72. if body:
  73. yield body
  74. if app_exc is not None:
  75. raise app_exc
  76. response = _StreamingResponse(
  77. status_code=message["status"], content=body_stream(), info=info
  78. )
  79. response.raw_headers = message["headers"]
  80. return response
  81. async with anyio.create_task_group() as task_group:
  82. request = Request(scope, receive=receive)
  83. response = await self.dispatch_func(request, call_next)
  84. await response(scope, receive, send)
  85. response_sent.set()
  86. async def dispatch(
  87. self, request: Request, call_next: RequestResponseEndpoint
  88. ) -> Response:
  89. raise NotImplementedError() # pragma: no cover
  90. class _StreamingResponse(StreamingResponse):
  91. def __init__(
  92. self,
  93. content: ContentStream,
  94. status_code: int = 200,
  95. headers: typing.Optional[typing.Mapping[str, str]] = None,
  96. media_type: typing.Optional[str] = None,
  97. background: typing.Optional[BackgroundTask] = None,
  98. info: typing.Optional[typing.Mapping[str, typing.Any]] = None,
  99. ) -> None:
  100. self._info = info
  101. super().__init__(content, status_code, headers, media_type, background)
  102. async def stream_response(self, send: Send) -> None:
  103. if self._info:
  104. await send({"type": "http.response.debug", "info": self._info})
  105. return await super().stream_response(send)