exceptions.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. import typing
  2. from starlette._utils import is_async_callable
  3. from starlette.concurrency import run_in_threadpool
  4. from starlette.exceptions import HTTPException, WebSocketException
  5. from starlette.requests import Request
  6. from starlette.responses import PlainTextResponse, Response
  7. from starlette.types import ASGIApp, Message, Receive, Scope, Send
  8. from starlette.websockets import WebSocket
  9. class ExceptionMiddleware:
  10. def __init__(
  11. self,
  12. app: ASGIApp,
  13. handlers: typing.Optional[
  14. typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]]
  15. ] = None,
  16. debug: bool = False,
  17. ) -> None:
  18. self.app = app
  19. self.debug = debug # TODO: We ought to handle 404 cases if debug is set.
  20. self._status_handlers: typing.Dict[int, typing.Callable] = {}
  21. self._exception_handlers: typing.Dict[
  22. typing.Type[Exception], typing.Callable
  23. ] = {
  24. HTTPException: self.http_exception,
  25. WebSocketException: self.websocket_exception,
  26. }
  27. if handlers is not None:
  28. for key, value in handlers.items():
  29. self.add_exception_handler(key, value)
  30. def add_exception_handler(
  31. self,
  32. exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
  33. handler: typing.Callable[[Request, Exception], Response],
  34. ) -> None:
  35. if isinstance(exc_class_or_status_code, int):
  36. self._status_handlers[exc_class_or_status_code] = handler
  37. else:
  38. assert issubclass(exc_class_or_status_code, Exception)
  39. self._exception_handlers[exc_class_or_status_code] = handler
  40. def _lookup_exception_handler(
  41. self, exc: Exception
  42. ) -> typing.Optional[typing.Callable]:
  43. for cls in type(exc).__mro__:
  44. if cls in self._exception_handlers:
  45. return self._exception_handlers[cls]
  46. return None
  47. async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
  48. if scope["type"] not in ("http", "websocket"):
  49. await self.app(scope, receive, send)
  50. return
  51. response_started = False
  52. async def sender(message: Message) -> None:
  53. nonlocal response_started
  54. if message["type"] == "http.response.start":
  55. response_started = True
  56. await send(message)
  57. try:
  58. await self.app(scope, receive, sender)
  59. except Exception as exc:
  60. handler = None
  61. if isinstance(exc, HTTPException):
  62. handler = self._status_handlers.get(exc.status_code)
  63. if handler is None:
  64. handler = self._lookup_exception_handler(exc)
  65. if handler is None:
  66. raise exc
  67. if response_started:
  68. msg = "Caught handled exception, but response already started."
  69. raise RuntimeError(msg) from exc
  70. if scope["type"] == "http":
  71. request = Request(scope, receive=receive)
  72. if is_async_callable(handler):
  73. response = await handler(request, exc)
  74. else:
  75. response = await run_in_threadpool(handler, request, exc)
  76. await response(scope, receive, sender)
  77. elif scope["type"] == "websocket":
  78. websocket = WebSocket(scope, receive=receive, send=send)
  79. if is_async_callable(handler):
  80. await handler(websocket, exc)
  81. else:
  82. await run_in_threadpool(handler, websocket, exc)
  83. def http_exception(self, request: Request, exc: HTTPException) -> Response:
  84. if exc.status_code in {204, 304}:
  85. return Response(status_code=exc.status_code, headers=exc.headers)
  86. return PlainTextResponse(
  87. exc.detail, status_code=exc.status_code, headers=exc.headers
  88. )
  89. async def websocket_exception(
  90. self, websocket: WebSocket, exc: WebSocketException
  91. ) -> None:
  92. await websocket.close(code=exc.code, reason=exc.reason)