# authentication.py
  1. import functools
  2. import inspect
  3. import typing
  4. from urllib.parse import urlencode
  5. from starlette._utils import is_async_callable
  6. from starlette.exceptions import HTTPException
  7. from starlette.requests import HTTPConnection, Request
  8. from starlette.responses import RedirectResponse, Response
  9. from starlette.websockets import WebSocket
  10. _CallableType = typing.TypeVar("_CallableType", bound=typing.Callable)
  11. def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bool:
  12. for scope in scopes:
  13. if scope not in conn.auth.scopes:
  14. return False
  15. return True
  16. def requires(
  17. scopes: typing.Union[str, typing.Sequence[str]],
  18. status_code: int = 403,
  19. redirect: typing.Optional[str] = None,
  20. ) -> typing.Callable[[_CallableType], _CallableType]:
  21. scopes_list = [scopes] if isinstance(scopes, str) else list(scopes)
  22. def decorator(func: typing.Callable) -> typing.Callable:
  23. sig = inspect.signature(func)
  24. for idx, parameter in enumerate(sig.parameters.values()):
  25. if parameter.name == "request" or parameter.name == "websocket":
  26. type_ = parameter.name
  27. break
  28. else:
  29. raise Exception(
  30. f'No "request" or "websocket" argument on function "{func}"'
  31. )
  32. if type_ == "websocket":
  33. # Handle websocket functions. (Always async)
  34. @functools.wraps(func)
  35. async def websocket_wrapper(
  36. *args: typing.Any, **kwargs: typing.Any
  37. ) -> None:
  38. websocket = kwargs.get(
  39. "websocket", args[idx] if idx < len(args) else None
  40. )
  41. assert isinstance(websocket, WebSocket)
  42. if not has_required_scope(websocket, scopes_list):
  43. await websocket.close()
  44. else:
  45. await func(*args, **kwargs)
  46. return websocket_wrapper
  47. elif is_async_callable(func):
  48. # Handle async request/response functions.
  49. @functools.wraps(func)
  50. async def async_wrapper(
  51. *args: typing.Any, **kwargs: typing.Any
  52. ) -> Response:
  53. request = kwargs.get("request", args[idx] if idx < len(args) else None)
  54. assert isinstance(request, Request)
  55. if not has_required_scope(request, scopes_list):
  56. if redirect is not None:
  57. orig_request_qparam = urlencode({"next": str(request.url)})
  58. next_url = "{redirect_path}?{orig_request}".format(
  59. redirect_path=request.url_for(redirect),
  60. orig_request=orig_request_qparam,
  61. )
  62. return RedirectResponse(url=next_url, status_code=303)
  63. raise HTTPException(status_code=status_code)
  64. return await func(*args, **kwargs)
  65. return async_wrapper
  66. else:
  67. # Handle sync request/response functions.
  68. @functools.wraps(func)
  69. def sync_wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response:
  70. request = kwargs.get("request", args[idx] if idx < len(args) else None)
  71. assert isinstance(request, Request)
  72. if not has_required_scope(request, scopes_list):
  73. if redirect is not None:
  74. orig_request_qparam = urlencode({"next": str(request.url)})
  75. next_url = "{redirect_path}?{orig_request}".format(
  76. redirect_path=request.url_for(redirect),
  77. orig_request=orig_request_qparam,
  78. )
  79. return RedirectResponse(url=next_url, status_code=303)
  80. raise HTTPException(status_code=status_code)
  81. return func(*args, **kwargs)
  82. return sync_wrapper
  83. return decorator # type: ignore[return-value]
  84. class AuthenticationError(Exception):
  85. pass
  86. class AuthenticationBackend:
  87. async def authenticate(
  88. self, conn: HTTPConnection
  89. ) -> typing.Optional[typing.Tuple["AuthCredentials", "BaseUser"]]:
  90. raise NotImplementedError() # pragma: no cover
  91. class AuthCredentials:
  92. def __init__(self, scopes: typing.Optional[typing.Sequence[str]] = None):
  93. self.scopes = [] if scopes is None else list(scopes)
  94. class BaseUser:
  95. @property
  96. def is_authenticated(self) -> bool:
  97. raise NotImplementedError() # pragma: no cover
  98. @property
  99. def display_name(self) -> str:
  100. raise NotImplementedError() # pragma: no cover
  101. @property
  102. def identity(self) -> str:
  103. raise NotImplementedError() # pragma: no cover
  104. class SimpleUser(BaseUser):
  105. def __init__(self, username: str) -> None:
  106. self.username = username
  107. @property
  108. def is_authenticated(self) -> bool:
  109. return True
  110. @property
  111. def display_name(self) -> str:
  112. return self.username
  113. class UnauthenticatedUser(BaseUser):
  114. @property
  115. def is_authenticated(self) -> bool:
  116. return False
  117. @property
  118. def display_name(self) -> str:
  119. return ""