| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134 |
- import typing
- import anyio
- from starlette.background import BackgroundTask
- from starlette.requests import Request
- from starlette.responses import ContentStream, Response, StreamingResponse
- from starlette.types import ASGIApp, Message, Receive, Scope, Send
# Generic result type used when wrapping arbitrary awaitable callables.
T = typing.TypeVar("T")

# A ``call_next``-style callable: takes the request, awaits to a response.
RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]

# A dispatch callable: takes the request plus a ``RequestResponseEndpoint``
# and awaits to the response that should be sent.
DispatchFunction = typing.Callable[
    [Request, RequestResponseEndpoint],
    typing.Awaitable[Response],
]
class BaseHTTPMiddleware:
    """ASGI middleware base class built around a request/response ``dispatch`` hook.

    For ``"http"`` scopes the wrapped app is started in a background task and
    its ASGI messages are relayed through an in-memory object stream, so the
    dispatch function can ``await call_next(request)`` and receive a response
    object. Any other scope type is forwarded to the wrapped app unchanged.
    """

    def __init__(
        self, app: ASGIApp, dispatch: typing.Optional[DispatchFunction] = None
    ) -> None:
        """Store the wrapped app and select the dispatch function.

        Args:
            app: The ASGI application this middleware wraps.
            dispatch: Optional dispatch callable. When ``None`` the instance's
                own ``dispatch`` method is used.
        """
        self.app = app
        self.dispatch_func = self.dispatch if dispatch is None else dispatch

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Handle one ASGI connection.

        Args:
            scope: The ASGI connection scope.
            receive: The ASGI receive callable of the outer server.
            send: The ASGI send callable of the outer server.
        """
        if scope["type"] != "http":
            # Only HTTP is intercepted; everything else passes straight through.
            await self.app(scope, receive, send)
            return

        # Set once the response produced by the dispatch function has been
        # sent to the client.
        response_sent = anyio.Event()

        async def call_next(request: Request) -> Response:
            """Run the wrapped app and return its response as a streaming response.

            Raises:
                RuntimeError: If the app finished without sending any message
                    and without raising.
                Exception: Whatever the wrapped app raised before it produced
                    a response (or, while streaming, before its final body
                    message).
            """
            app_exc: typing.Optional[Exception] = None
            send_stream, recv_stream = anyio.create_memory_object_stream()

            async def receive_or_disconnect() -> Message:
                """Return the next client message, or ``http.disconnect`` once
                the response has been sent."""
                if response_sent.is_set():
                    return {"type": "http.disconnect"}

                async with anyio.create_task_group() as task_group:

                    async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T:
                        # Whichever wrapped awaitable completes first cancels
                        # the other one via the task group's cancel scope.
                        result = await func()
                        task_group.cancel_scope.cancel()
                        return result

                    # Race "response sent" against "next message from client".
                    task_group.start_soon(wrap, response_sent.wait)
                    message = await wrap(request.receive)

                # If the receive was cancelled because the response was sent,
                # ``message`` was never assigned; this check returns first.
                if response_sent.is_set():
                    return {"type": "http.disconnect"}

                return message

            async def close_recv_stream_on_response_sent() -> None:
                """Close the receiving end as soon as the response is sent."""
                await response_sent.wait()
                recv_stream.close()

            async def send_no_error(message: Message) -> None:
                """Forward an app message to the stream, ignoring a closed receiver."""
                try:
                    await send_stream.send(message)
                except anyio.BrokenResourceError:
                    # recv_stream has been closed, i.e. response_sent has been set.
                    return

            async def coro() -> None:
                """Run the wrapped app, recording (not propagating) its exception."""
                nonlocal app_exc

                # Leaving this context closes send_stream, which ends
                # iteration on recv_stream once the app has returned.
                async with send_stream:
                    try:
                        await self.app(scope, receive_or_disconnect, send_no_error)
                    except Exception as exc:
                        app_exc = exc

            task_group.start_soon(close_recv_stream_on_response_sent)
            task_group.start_soon(coro)

            try:
                message = await recv_stream.receive()
                info = message.get("info", None)
                if message["type"] == "http.response.debug" and info is not None:
                    # A debug message precedes the real response start; keep
                    # its payload and read the next message.
                    message = await recv_stream.receive()
            except anyio.EndOfStream:
                # The app ended without sending a response.
                if app_exc is not None:
                    raise app_exc
                raise RuntimeError("No response returned.")

            assert message["type"] == "http.response.start"

            async def body_stream() -> typing.AsyncGenerator[bytes, None]:
                """Yield the non-empty body chunks sent by the wrapped app."""
                async with recv_stream:
                    async for message in recv_stream:
                        assert message["type"] == "http.response.body"
                        body = message.get("body", b"")
                        if body:
                            yield body
                        if not message.get("more_body", False):
                            # ASGI marks the final body message with a false
                            # (or absent) ``more_body``. Stop here instead of
                            # waiting for the app task to return, so the
                            # response can complete even if the app keeps
                            # running after its last body message.
                            # NOTE: an exception raised by the app after this
                            # point is stored in app_exc but is not re-raised
                            # by this generator.
                            break

                if app_exc is not None:
                    raise app_exc

            response = _StreamingResponse(
                status_code=message["status"], content=body_stream(), info=info
            )
            response.raw_headers = message["headers"]
            return response

        # ``call_next`` starts its helper tasks in this task group, so the
        # group only exits after the wrapped app task has finished.
        async with anyio.create_task_group() as task_group:
            request = Request(scope, receive=receive)
            response = await self.dispatch_func(request, call_next)
            await response(scope, receive, send)
            response_sent.set()

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        """Default dispatch hook, used when no ``dispatch`` callable was
        passed to the constructor.

        Args:
            request: The incoming request.
            call_next: Awaitable callable that runs the wrapped app and
                returns its response.

        Raises:
            NotImplementedError: Always; there is no default implementation.
        """
        raise NotImplementedError()  # pragma: no cover
class _StreamingResponse(StreamingResponse):
    """Streaming response that can emit an ``http.response.debug`` message first.

    The optional ``info`` mapping is kept on the instance and, when truthy,
    is sent as a debug message before the regular streamed response.
    """

    def __init__(
        self,
        content: ContentStream,
        status_code: int = 200,
        headers: typing.Optional[typing.Mapping[str, str]] = None,
        media_type: typing.Optional[str] = None,
        background: typing.Optional[BackgroundTask] = None,
        info: typing.Optional[typing.Mapping[str, typing.Any]] = None,
    ) -> None:
        """Remember ``info`` and hand every other argument to the parent class."""
        self._info = info
        super().__init__(
            content=content,
            status_code=status_code,
            headers=headers,
            media_type=media_type,
            background=background,
        )

    async def stream_response(self, send: Send) -> None:
        """Send the stored debug info (if any), then stream as the parent does."""
        debug_info = self._info
        if debug_info:
            await send({"type": "http.response.debug", "info": debug_info})
        return await super().stream_response(send)
|