| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109 |
- import typing
- from starlette._utils import is_async_callable
- from starlette.concurrency import run_in_threadpool
- from starlette.exceptions import HTTPException, WebSocketException
- from starlette.requests import Request
- from starlette.responses import PlainTextResponse, Response
- from starlette.types import ASGIApp, Message, Receive, Scope, Send
- from starlette.websockets import WebSocket
class ExceptionMiddleware:
    """ASGI middleware that routes exceptions from the wrapped app to handlers.

    Handlers are registered either against an integer status code (consulted
    only for ``HTTPException`` instances, by ``exc.status_code``) or against an
    exception class (matched by walking the raised exception's MRO).
    ``HTTPException`` and ``WebSocketException`` get default handlers.
    """

    def __init__(
        self,
        app: ASGIApp,
        handlers: typing.Optional[
            typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]]
        ] = None,
        debug: bool = False,
    ) -> None:
        """Store the wrapped app and register default plus caller handlers.

        Args:
            app: The ASGI application to wrap.
            handlers: Optional mapping of status code or exception class to a
                handler callable; each entry is passed through
                ``add_exception_handler``, so it may override a default.
            debug: Stored on the instance; not otherwise used by this class.
        """
        self.app = app
        self.debug = debug  # TODO: We ought to handle 404 cases if debug is set.
        self._status_handlers: typing.Dict[int, typing.Callable] = {}
        self._exception_handlers: typing.Dict[
            typing.Type[Exception], typing.Callable
        ] = {
            HTTPException: self.http_exception,
            WebSocketException: self.websocket_exception,
        }
        if handlers is not None:
            for key, value in handlers.items():
                self.add_exception_handler(key, value)

    def add_exception_handler(
        self,
        exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
        handler: typing.Callable[[Request, Exception], Response],
    ) -> None:
        """Register ``handler`` for a status code or an exception class.

        Args:
            exc_class_or_status_code: An ``int`` status code, or a subclass of
                ``Exception``.
            handler: Callable invoked with the request (or websocket) and the
                raised exception.

        Raises:
            TypeError: If the key is neither an ``int`` nor an ``Exception``
                subclass.
        """
        if isinstance(exc_class_or_status_code, int):
            self._status_handlers[exc_class_or_status_code] = handler
            return

        # Validate explicitly rather than with ``assert``: assertions are
        # stripped under ``python -O``, which would let a bad key through.
        is_exception_class = isinstance(
            exc_class_or_status_code, type
        ) and issubclass(exc_class_or_status_code, Exception)
        if not is_exception_class:
            raise TypeError(
                "exc_class_or_status_code must be an int status code or an "
                f"Exception subclass, got {exc_class_or_status_code!r}"
            )
        self._exception_handlers[exc_class_or_status_code] = handler

    def _lookup_exception_handler(
        self, exc: Exception
    ) -> typing.Optional[typing.Callable]:
        """Return the handler for the nearest registered class in ``exc``'s MRO.

        Returns ``None`` when no class in the MRO has a registered handler.
        """
        for cls in type(exc).__mro__:
            if cls in self._exception_handlers:
                return self._exception_handlers[cls]
        return None

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Run the wrapped app, dispatching any raised exception to a handler.

        Scopes other than ``"http"`` and ``"websocket"`` are passed straight
        through. An exception with no matching handler is re-raised.

        Raises:
            RuntimeError: If a handler matches but ``http.response.start`` has
                already been sent.
        """
        if scope["type"] not in ("http", "websocket"):
            await self.app(scope, receive, send)
            return

        response_started = False

        async def sender(message: Message) -> None:
            # Record that the response has begun, so a handler is not asked to
            # produce a second response later.
            nonlocal response_started

            if message["type"] == "http.response.start":
                response_started = True
            await send(message)

        try:
            await self.app(scope, receive, sender)
        except Exception as exc:
            handler = None

            # Status-code handlers take precedence over class-based handlers,
            # but only apply to HTTPException.
            if isinstance(exc, HTTPException):
                handler = self._status_handlers.get(exc.status_code)

            if handler is None:
                handler = self._lookup_exception_handler(exc)

            if handler is None:
                # Nothing registered: propagate the original exception as-is.
                raise

            if response_started:
                msg = "Caught handled exception, but response already started."
                raise RuntimeError(msg) from exc

            if scope["type"] == "http":
                request = Request(scope, receive=receive)
                if is_async_callable(handler):
                    response = await handler(request, exc)
                else:
                    # Sync handlers are run in a worker thread.
                    response = await run_in_threadpool(handler, request, exc)
                await response(scope, receive, sender)
            elif scope["type"] == "websocket":
                websocket = WebSocket(scope, receive=receive, send=send)
                if is_async_callable(handler):
                    await handler(websocket, exc)
                else:
                    await run_in_threadpool(handler, websocket, exc)

    def http_exception(self, request: Request, exc: HTTPException) -> Response:
        """Default ``HTTPException`` handler.

        Returns a bare ``Response`` for status codes 204 and 304, otherwise a
        ``PlainTextResponse`` carrying ``exc.detail``. The exception's status
        code and headers are used in both cases.
        """
        if exc.status_code in {204, 304}:
            return Response(status_code=exc.status_code, headers=exc.headers)
        return PlainTextResponse(
            exc.detail, status_code=exc.status_code, headers=exc.headers
        )

    async def websocket_exception(
        self, websocket: WebSocket, exc: WebSocketException
    ) -> None:
        """Default ``WebSocketException`` handler: close with its code and reason."""
        await websocket.close(code=exc.code, reason=exc.reason)
|