endpoints.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. import json
  2. import typing
  3. from starlette import status
  4. from starlette._utils import is_async_callable
  5. from starlette.concurrency import run_in_threadpool
  6. from starlette.exceptions import HTTPException
  7. from starlette.requests import Request
  8. from starlette.responses import PlainTextResponse, Response
  9. from starlette.types import Message, Receive, Scope, Send
  10. from starlette.websockets import WebSocket
  11. class HTTPEndpoint:
  12. def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
  13. assert scope["type"] == "http"
  14. self.scope = scope
  15. self.receive = receive
  16. self.send = send
  17. self._allowed_methods = [
  18. method
  19. for method in ("GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "OPTIONS")
  20. if getattr(self, method.lower(), None) is not None
  21. ]
  22. def __await__(self) -> typing.Generator:
  23. return self.dispatch().__await__()
  24. async def dispatch(self) -> None:
  25. request = Request(self.scope, receive=self.receive)
  26. handler_name = (
  27. "get"
  28. if request.method == "HEAD" and not hasattr(self, "head")
  29. else request.method.lower()
  30. )
  31. handler: typing.Callable[[Request], typing.Any] = getattr(
  32. self, handler_name, self.method_not_allowed
  33. )
  34. is_async = is_async_callable(handler)
  35. if is_async:
  36. response = await handler(request)
  37. else:
  38. response = await run_in_threadpool(handler, request)
  39. await response(self.scope, self.receive, self.send)
  40. async def method_not_allowed(self, request: Request) -> Response:
  41. # If we're running inside a starlette application then raise an
  42. # exception, so that the configurable exception handler can deal with
  43. # returning the response. For plain ASGI apps, just return the response.
  44. headers = {"Allow": ", ".join(self._allowed_methods)}
  45. if "app" in self.scope:
  46. raise HTTPException(status_code=405, headers=headers)
  47. return PlainTextResponse("Method Not Allowed", status_code=405, headers=headers)
  48. class WebSocketEndpoint:
  49. encoding: typing.Optional[str] = None # May be "text", "bytes", or "json".
  50. def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
  51. assert scope["type"] == "websocket"
  52. self.scope = scope
  53. self.receive = receive
  54. self.send = send
  55. def __await__(self) -> typing.Generator:
  56. return self.dispatch().__await__()
  57. async def dispatch(self) -> None:
  58. websocket = WebSocket(self.scope, receive=self.receive, send=self.send)
  59. await self.on_connect(websocket)
  60. close_code = status.WS_1000_NORMAL_CLOSURE
  61. try:
  62. while True:
  63. message = await websocket.receive()
  64. if message["type"] == "websocket.receive":
  65. data = await self.decode(websocket, message)
  66. await self.on_receive(websocket, data)
  67. elif message["type"] == "websocket.disconnect":
  68. close_code = int(
  69. message.get("code") or status.WS_1000_NORMAL_CLOSURE
  70. )
  71. break
  72. except Exception as exc:
  73. close_code = status.WS_1011_INTERNAL_ERROR
  74. raise exc
  75. finally:
  76. await self.on_disconnect(websocket, close_code)
  77. async def decode(self, websocket: WebSocket, message: Message) -> typing.Any:
  78. if self.encoding == "text":
  79. if "text" not in message:
  80. await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
  81. raise RuntimeError("Expected text websocket messages, but got bytes")
  82. return message["text"]
  83. elif self.encoding == "bytes":
  84. if "bytes" not in message:
  85. await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
  86. raise RuntimeError("Expected bytes websocket messages, but got text")
  87. return message["bytes"]
  88. elif self.encoding == "json":
  89. if message.get("text") is not None:
  90. text = message["text"]
  91. else:
  92. text = message["bytes"].decode("utf-8")
  93. try:
  94. return json.loads(text)
  95. except json.decoder.JSONDecodeError:
  96. await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
  97. raise RuntimeError("Malformed JSON data received.")
  98. assert (
  99. self.encoding is None
  100. ), f"Unsupported 'encoding' attribute {self.encoding}"
  101. return message["text"] if message.get("text") else message["bytes"]
  102. async def on_connect(self, websocket: WebSocket) -> None:
  103. """Override to handle an incoming websocket connection"""
  104. await websocket.accept()
  105. async def on_receive(self, websocket: WebSocket, data: typing.Any) -> None:
  106. """Override to handle an incoming websocket message"""
  107. async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None:
  108. """Override to handle a disconnecting websocket"""