sessions.py 3.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. import json
  2. import sys
  3. import typing
  4. from base64 import b64decode, b64encode
  5. import itsdangerous
  6. from itsdangerous.exc import BadSignature
  7. from starlette.datastructures import MutableHeaders, Secret
  8. from starlette.requests import HTTPConnection
  9. from starlette.types import ASGIApp, Message, Receive, Scope, Send
  10. if sys.version_info >= (3, 8): # pragma: no cover
  11. from typing import Literal
  12. else: # pragma: no cover
  13. from typing_extensions import Literal
  14. class SessionMiddleware:
  15. def __init__(
  16. self,
  17. app: ASGIApp,
  18. secret_key: typing.Union[str, Secret],
  19. session_cookie: str = "session",
  20. max_age: typing.Optional[int] = 14 * 24 * 60 * 60, # 14 days, in seconds
  21. path: str = "/",
  22. same_site: Literal["lax", "strict", "none"] = "lax",
  23. https_only: bool = False,
  24. ) -> None:
  25. self.app = app
  26. self.signer = itsdangerous.TimestampSigner(str(secret_key))
  27. self.session_cookie = session_cookie
  28. self.max_age = max_age
  29. self.path = path
  30. self.security_flags = "httponly; samesite=" + same_site
  31. if https_only: # Secure flag can be used with HTTPS only
  32. self.security_flags += "; secure"
  33. async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
  34. if scope["type"] not in ("http", "websocket"): # pragma: no cover
  35. await self.app(scope, receive, send)
  36. return
  37. connection = HTTPConnection(scope)
  38. initial_session_was_empty = True
  39. if self.session_cookie in connection.cookies:
  40. data = connection.cookies[self.session_cookie].encode("utf-8")
  41. try:
  42. data = self.signer.unsign(data, max_age=self.max_age)
  43. scope["session"] = json.loads(b64decode(data))
  44. initial_session_was_empty = False
  45. except BadSignature:
  46. scope["session"] = {}
  47. else:
  48. scope["session"] = {}
  49. async def send_wrapper(message: Message) -> None:
  50. if message["type"] == "http.response.start":
  51. if scope["session"]:
  52. # We have session data to persist.
  53. data = b64encode(json.dumps(scope["session"]).encode("utf-8"))
  54. data = self.signer.sign(data)
  55. headers = MutableHeaders(scope=message)
  56. header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format( # noqa E501
  57. session_cookie=self.session_cookie,
  58. data=data.decode("utf-8"),
  59. path=self.path,
  60. max_age=f"Max-Age={self.max_age}; " if self.max_age else "",
  61. security_flags=self.security_flags,
  62. )
  63. headers.append("Set-Cookie", header_value)
  64. elif not initial_session_was_empty:
  65. # The session has been cleared.
  66. headers = MutableHeaders(scope=message)
  67. header_value = "{session_cookie}={data}; path={path}; {expires}{security_flags}".format( # noqa E501
  68. session_cookie=self.session_cookie,
  69. data="null",
  70. path=self.path,
  71. expires="expires=Thu, 01 Jan 1970 00:00:00 GMT; ",
  72. security_flags=self.security_flags,
  73. )
  74. headers.append("Set-Cookie", header_value)
  75. await send(message)
  76. await self.app(scope, receive, send_wrapper)