| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153 |
- import functools
- import inspect
- import typing
- from urllib.parse import urlencode
- from starlette._utils import is_async_callable
- from starlette.exceptions import HTTPException
- from starlette.requests import HTTPConnection, Request
- from starlette.responses import RedirectResponse, Response
- from starlette.websockets import WebSocket
# Type variable bound to callables. `requires` uses it in its return
# annotation so the decorator is typed as handing back the same callable
# type it was given.
_CallableType = typing.TypeVar("_CallableType", bound=typing.Callable)
def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bool:
    """Report whether the connection holds every one of the given scopes.

    Args:
        conn: Connection whose ``auth.scopes`` collection is consulted.
        scopes: Scope names that must all be present.

    Returns:
        ``True`` when each entry of ``scopes`` is found in
        ``conn.auth.scopes`` (trivially ``True`` for an empty sequence),
        ``False`` as soon as one is missing.
    """
    # The generator keeps the lookup lazy: ``conn.auth`` is only touched when
    # there is at least one scope to check, and checking stops at the first
    # missing scope.
    return all(needed in conn.auth.scopes for needed in scopes)
def requires(
    scopes: typing.Union[str, typing.Sequence[str]],
    status_code: int = 403,
    redirect: typing.Optional[str] = None,
) -> typing.Callable[[_CallableType], _CallableType]:
    """Build a decorator that runs an endpoint only when scopes are granted.

    The decorated callable must declare a parameter named ``request`` or
    ``websocket``. On every call that argument is located (by keyword or by
    position) and checked with ``has_required_scope``.

    Args:
        scopes: A single scope name, or a sequence of scope names that must
            all be present on the connection.
        status_code: Status code of the ``HTTPException`` raised for a
            request lacking the scopes when ``redirect`` is not given.
        redirect: Optional route name handed to ``request.url_for``. When
            set, a request lacking the scopes gets a 303 redirect to that
            route, with the original URL in the ``next`` query parameter,
            instead of an ``HTTPException``.

    Returns:
        A decorator. Functions with a ``websocket`` parameter are wrapped in
        an async wrapper that closes the websocket when scopes are missing;
        functions with a ``request`` parameter are wrapped in an async or
        sync wrapper matching the wrapped function.

    Raises:
        Exception: At decoration time, if the function declares neither a
            ``request`` nor a ``websocket`` parameter.
    """
    scopes_list = [scopes] if isinstance(scopes, str) else list(scopes)

    def _reject(request: Request) -> Response:
        """Return the redirect for an unauthorized request, or raise."""
        if redirect is None:
            raise HTTPException(status_code=status_code)
        # Carry the originally requested URL along so the redirect target
        # can send the client back afterwards.
        orig_request_qparam = urlencode({"next": str(request.url)})
        next_url = "{redirect_path}?{orig_request}".format(
            redirect_path=request.url_for(redirect),
            orig_request=orig_request_qparam,
        )
        return RedirectResponse(url=next_url, status_code=303)

    def decorator(func: typing.Callable) -> typing.Callable:
        sig = inspect.signature(func)
        # Remember both the name and the position of the connection
        # parameter so it can be found whether passed by keyword or
        # positionally.
        for idx, parameter in enumerate(sig.parameters.values()):
            if parameter.name in ("request", "websocket"):
                type_ = parameter.name
                break
        else:
            raise Exception(
                f'No "request" or "websocket" argument on function "{func}"'
            )

        def _connection(
            args: typing.Tuple[typing.Any, ...],
            kwargs: typing.Dict[str, typing.Any],
        ) -> typing.Any:
            """Pick the connection argument out of a call's arguments."""
            return kwargs.get(type_, args[idx] if idx < len(args) else None)

        if type_ == "websocket":
            # Handle websocket functions. (Always async)
            @functools.wraps(func)
            async def websocket_wrapper(
                *args: typing.Any, **kwargs: typing.Any
            ) -> None:
                websocket = _connection(args, kwargs)
                assert isinstance(websocket, WebSocket)
                if not has_required_scope(websocket, scopes_list):
                    await websocket.close()
                else:
                    await func(*args, **kwargs)

            return websocket_wrapper

        elif is_async_callable(func):
            # Handle async request/response functions.
            @functools.wraps(func)
            async def async_wrapper(
                *args: typing.Any, **kwargs: typing.Any
            ) -> Response:
                request = _connection(args, kwargs)
                assert isinstance(request, Request)
                if not has_required_scope(request, scopes_list):
                    return _reject(request)
                return await func(*args, **kwargs)

            return async_wrapper

        else:
            # Handle sync request/response functions.
            @functools.wraps(func)
            def sync_wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response:
                request = _connection(args, kwargs)
                assert isinstance(request, Request)
                if not has_required_scope(request, scopes_list):
                    return _reject(request)
                return func(*args, **kwargs)

            return sync_wrapper

    return decorator  # type: ignore[return-value]
class AuthenticationError(Exception):
    """Exception type for authentication failures; adds nothing to ``Exception``."""
class AuthenticationBackend:
    """Base class for authentication backends; subclasses override ``authenticate``."""

    async def authenticate(
        self, conn: HTTPConnection
    ) -> typing.Optional[typing.Tuple["AuthCredentials", "BaseUser"]]:
        """Authenticate the given connection.

        The annotated result is either ``None`` or a
        ``(AuthCredentials, BaseUser)`` pair. This base implementation always
        raises ``NotImplementedError``.
        """
        raise NotImplementedError  # pragma: no cover
class AuthCredentials:
    """Holder for the list of scope names granted to a connection."""

    def __init__(self, scopes: typing.Optional[typing.Sequence[str]] = None):
        """Store a fresh list copy of ``scopes`` (empty when ``None``)."""
        if scopes is not None:
            # Copy so later changes to the caller's sequence do not leak in.
            self.scopes = list(scopes)
        else:
            self.scopes = []
class BaseUser:
    """Interface for user objects; every property here must be overridden."""

    @property
    def is_authenticated(self) -> bool:
        """Whether the user is authenticated (not implemented on the base)."""
        raise NotImplementedError  # pragma: no cover

    @property
    def display_name(self) -> str:
        """Name to display for the user (not implemented on the base)."""
        raise NotImplementedError  # pragma: no cover

    @property
    def identity(self) -> str:
        """Identity string for the user (not implemented on the base)."""
        raise NotImplementedError  # pragma: no cover
class SimpleUser(BaseUser):
    """An authenticated user identified only by a username."""

    def __init__(self, username: str) -> None:
        """Remember ``username`` for later display."""
        self.username = username

    @property
    def is_authenticated(self) -> bool:
        """Always ``True``."""
        return True

    @property
    def display_name(self) -> str:
        """The username supplied at construction."""
        return self.username
class UnauthenticatedUser(BaseUser):
    """A user that is not authenticated and has an empty display name."""

    @property
    def is_authenticated(self) -> bool:
        """Always ``False``."""
        return False

    @property
    def display_name(self) -> str:
        """Always the empty string."""
        return ""
|