websockets.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. import enum
  2. import json
  3. import typing
  4. from starlette.requests import HTTPConnection
  5. from starlette.types import Message, Receive, Scope, Send
  class WebSocketState(enum.Enum):
      """Lifecycle phase of one direction (client or application) of a websocket.

      Values are spelled out explicitly starting at 0; ``enum.auto()`` would
      start numbering at 1 and change them.
      """

      # Before the handshake message for this direction has been seen.
      CONNECTING = 0
      # Handshake done; regular data messages are allowed.
      CONNECTED = 1
      # Close/disconnect has happened; no further messages in this direction.
      DISCONNECTED = 2
  class WebSocketDisconnect(Exception):
      """Raised when the peer disconnects while the application awaits data.

      Attributes:
          code: The websocket close code (defaults to 1000).
          reason: The close reason; ``None`` or any falsy value becomes ``""``.
      """

      def __init__(self, code: int = 1000, reason: typing.Optional[str] = None) -> None:
          self.code = code
          self.reason = reason or ""
          # Populate ``args`` so ``repr``/``str``/logging show the close code
          # instead of an empty exception.
          super().__init__(self.code, self.reason)
  class WebSocket(HTTPConnection):
      """An ASGI websocket connection that enforces valid message ordering.

      Two independent state machines are tracked:

      * ``client_state`` follows messages *received* via ``receive``:
        ``websocket.connect`` -> ``websocket.receive``* -> ``websocket.disconnect``.
      * ``application_state`` follows messages *sent* via ``send``:
        ``websocket.accept`` (or ``websocket.close``) -> ``websocket.send``* ->
        ``websocket.close``.

      Any out-of-order message raises ``RuntimeError``.
      """

      def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
          super().__init__(scope)
          # NOTE(review): ``assert`` is stripped under ``python -O``.
          assert scope["type"] == "websocket"
          self._receive = receive
          self._send = send
          self.client_state = WebSocketState.CONNECTING
          self.application_state = WebSocketState.CONNECTING

      async def receive(self) -> Message:
          """
          Receive ASGI websocket messages, ensuring valid state transitions.

          Returns:
              The raw ASGI message.

          Raises:
              RuntimeError: If the message type is not valid for the current
                  ``client_state``, or a disconnect was already received.
          """
          if self.client_state == WebSocketState.CONNECTING:
              message = await self._receive()
              message_type = message["type"]
              if message_type != "websocket.connect":
                  raise RuntimeError(
                      'Expected ASGI message "websocket.connect", '
                      f"but got {message_type!r}"
                  )
              self.client_state = WebSocketState.CONNECTED
              return message
          elif self.client_state == WebSocketState.CONNECTED:
              message = await self._receive()
              message_type = message["type"]
              if message_type not in {"websocket.receive", "websocket.disconnect"}:
                  raise RuntimeError(
                      'Expected ASGI message "websocket.receive" or '
                      f'"websocket.disconnect", but got {message_type!r}'
                  )
              if message_type == "websocket.disconnect":
                  self.client_state = WebSocketState.DISCONNECTED
              return message
          else:
              raise RuntimeError(
                  'Cannot call "receive" once a disconnect message has been received.'
              )

      async def send(self, message: Message) -> None:
          """
          Send ASGI websocket messages, ensuring valid state transitions.

          Args:
              message: The raw ASGI message to forward to the server.

          Raises:
              RuntimeError: If the message type is not valid for the current
                  ``application_state``, or a close was already sent.
          """
          if self.application_state == WebSocketState.CONNECTING:
              message_type = message["type"]
              if message_type not in {"websocket.accept", "websocket.close"}:
                  raise RuntimeError(
                      'Expected ASGI message "websocket.accept" or '
                      f'"websocket.close", but got {message_type!r}'
                  )
              # Closing before accepting rejects the handshake outright.
              if message_type == "websocket.close":
                  self.application_state = WebSocketState.DISCONNECTED
              else:
                  self.application_state = WebSocketState.CONNECTED
              await self._send(message)
          elif self.application_state == WebSocketState.CONNECTED:
              message_type = message["type"]
              if message_type not in {"websocket.send", "websocket.close"}:
                  raise RuntimeError(
                      'Expected ASGI message "websocket.send" or "websocket.close", '
                      f"but got {message_type!r}"
                  )
              if message_type == "websocket.close":
                  self.application_state = WebSocketState.DISCONNECTED
              await self._send(message)
          else:
              raise RuntimeError('Cannot call "send" once a close message has been sent.')

      async def accept(
          self,
          subprotocol: typing.Optional[str] = None,
          headers: typing.Optional[typing.Iterable[typing.Tuple[bytes, bytes]]] = None,
      ) -> None:
          """Accept the connection.

          Args:
              subprotocol: Subprotocol to report in the ``websocket.accept`` message.
              headers: Extra ``(name, value)`` byte pairs sent with the accept
                  message; ``None`` means no extra headers.
          """
          headers = headers or []
          if self.client_state == WebSocketState.CONNECTING:
              # If we haven't yet seen the 'connect' message, then wait for it first.
              await self.receive()
          await self.send(
              {"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers}
          )

      def _require_accepted(self) -> None:
          """Raise RuntimeError unless the application has accepted the socket."""
          if self.application_state != WebSocketState.CONNECTED:
              raise RuntimeError(
                  'WebSocket is not connected. Need to call "accept" first.'
              )

      def _raise_on_disconnect(self, message: Message) -> None:
          """Raise WebSocketDisconnect if *message* is a disconnect event."""
          if message["type"] == "websocket.disconnect":
              # Forward the close reason as well as the code. ``reason`` may be
              # absent from the message, hence ``.get``.
              raise WebSocketDisconnect(message["code"], message.get("reason"))

      async def receive_text(self) -> str:
          """Receive one text frame.

          Raises:
              RuntimeError: If ``accept`` has not been called.
              WebSocketDisconnect: If the peer disconnected.
          """
          self._require_accepted()
          message = await self.receive()
          self._raise_on_disconnect(message)
          return message["text"]

      async def receive_bytes(self) -> bytes:
          """Receive one binary frame.

          Raises:
              RuntimeError: If ``accept`` has not been called.
              WebSocketDisconnect: If the peer disconnected.
          """
          self._require_accepted()
          message = await self.receive()
          self._raise_on_disconnect(message)
          return message["bytes"]

      async def receive_json(self, mode: str = "text") -> typing.Any:
          """Receive one frame and decode it as JSON.

          Args:
              mode: ``"text"`` reads the ``text`` field; ``"binary"`` reads the
                  ``bytes`` field and decodes it as UTF-8.

          Raises:
              RuntimeError: If *mode* is invalid or ``accept`` has not been called.
              WebSocketDisconnect: If the peer disconnected.
          """
          if mode not in {"text", "binary"}:
              raise RuntimeError('The "mode" argument should be "text" or "binary".')
          self._require_accepted()
          message = await self.receive()
          self._raise_on_disconnect(message)
          if mode == "text":
              text = message["text"]
          else:
              text = message["bytes"].decode("utf-8")
          return json.loads(text)

      async def iter_text(self) -> typing.AsyncIterator[str]:
          """Yield text frames until the peer disconnects."""
          try:
              while True:
                  yield await self.receive_text()
          except WebSocketDisconnect:
              # A disconnect is the normal end of iteration.
              pass

      async def iter_bytes(self) -> typing.AsyncIterator[bytes]:
          """Yield binary frames until the peer disconnects."""
          try:
              while True:
                  yield await self.receive_bytes()
          except WebSocketDisconnect:
              # A disconnect is the normal end of iteration.
              pass

      async def iter_json(self, mode: str = "text") -> typing.AsyncIterator[typing.Any]:
          """Yield JSON-decoded frames until the peer disconnects.

          Args:
              mode: Forwarded to ``receive_json`` (``"text"`` or ``"binary"``).
          """
          try:
              while True:
                  yield await self.receive_json(mode=mode)
          except WebSocketDisconnect:
              # A disconnect is the normal end of iteration.
              pass

      async def send_text(self, data: str) -> None:
          """Send *data* as a text frame."""
          await self.send({"type": "websocket.send", "text": data})

      async def send_bytes(self, data: bytes) -> None:
          """Send *data* as a binary frame."""
          await self.send({"type": "websocket.send", "bytes": data})

      async def send_json(self, data: typing.Any, mode: str = "text") -> None:
          """Serialize *data* as compact JSON and send it.

          Args:
              data: A JSON-serializable value.
              mode: ``"text"`` sends a text frame; ``"binary"`` sends the UTF-8
                  encoded JSON as a binary frame.

          Raises:
              RuntimeError: If *mode* is invalid.
          """
          if mode not in {"text", "binary"}:
              raise RuntimeError('The "mode" argument should be "text" or "binary".')
          text = json.dumps(data, separators=(",", ":"))
          if mode == "text":
              await self.send({"type": "websocket.send", "text": text})
          else:
              await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")})

      async def close(
          self, code: int = 1000, reason: typing.Optional[str] = None
      ) -> None:
          """Send a ``websocket.close`` message with *code* and *reason* (``""`` if falsy)."""
          await self.send(
              {"type": "websocket.close", "code": code, "reason": reason or ""}
          )
  class WebSocketClose:
      """ASGI application that immediately closes a websocket.

      Args:
          code: Close code placed in the ``websocket.close`` message.
          reason: Close reason; ``None`` or any falsy value becomes ``""``.
      """

      def __init__(self, code: int = 1000, reason: typing.Optional[str] = None) -> None:
          self.code = code
          self.reason = reason if reason else ""

      async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
          close_message = {
              "type": "websocket.close",
              "code": self.code,
              "reason": self.reason,
          }
          await send(close_message)