| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862 |
- import contextlib
- import functools
- import inspect
- import re
- import traceback
- import types
- import typing
- import warnings
- from contextlib import asynccontextmanager
- from enum import Enum
- from starlette._utils import is_async_callable
- from starlette.concurrency import run_in_threadpool
- from starlette.convertors import CONVERTOR_TYPES, Convertor
- from starlette.datastructures import URL, Headers, URLPath
- from starlette.exceptions import HTTPException
- from starlette.middleware import Middleware
- from starlette.requests import Request
- from starlette.responses import PlainTextResponse, RedirectResponse
- from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send
- from starlette.websockets import WebSocket, WebSocketClose
class NoMatchFound(Exception):
    """
    Signal that no route can build a URL for the requested name and params.

    Raised by `.url_for(name, **path_params)` and
    `.url_path_for(name, **path_params)` when no matching route exists.
    """

    def __init__(self, name: str, path_params: typing.Dict[str, typing.Any]) -> None:
        # Only the parameter *names* go into the message, never their values.
        joined_names = ", ".join(path_params)
        message = f'No route exists for name "{name}" and params "{joined_names}".'
        super().__init__(message)
class Match(Enum):
    """
    Result of testing a route against an incoming ASGI scope.

    Values increase from "does not apply" to "applies completely".
    """

    # The route does not apply to this scope at all.
    NONE = 0
    # The route could handle the scope but is not the preferred option
    # (Route reports this when the HTTP method is not in its allowed set).
    PARTIAL = 1
    # The route handles this scope.
    FULL = 2
def iscoroutinefunction_or_partial(obj: typing.Any) -> bool:  # pragma: no cover
    """
    Deprecated: report whether `obj` is a coroutine function.

    Any number of nested `functools.partial` wrappers are peeled off before
    checking. Emits a DeprecationWarning on every call.
    """
    warnings.warn(
        "iscoroutinefunction_or_partial is deprecated, "
        "and will be removed in a future release.",
        DeprecationWarning,
    )
    target = obj
    while isinstance(target, functools.partial):
        target = target.func  # step one wrapper inward
    return inspect.iscoroutinefunction(target)
def request_response(func: typing.Callable) -> ASGIApp:
    """
    Adapt a request handler `func(request) -> response` into an ASGI app.

    Whether `func` is async is decided once, when the adapter is built. Async
    handlers are awaited directly; anything else is handed to
    `run_in_threadpool`. The returned response object is then called as an
    ASGI app.
    """
    func_is_async = is_async_callable(func)

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        request = Request(scope, receive=receive, send=send)
        response = (
            await func(request)
            if func_is_async
            else await run_in_threadpool(func, request)
        )
        await response(scope, receive, send)

    return app
def websocket_session(func: typing.Callable) -> ASGIApp:
    """
    Adapt a coroutine `func(websocket)` into an ASGI application.

    Each connection builds a `WebSocket` from the ASGI scope and channels.
    `func` is awaited with it, and its return value is discarded.
    """

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        await func(WebSocket(scope, receive=receive, send=send))

    return app
def get_name(endpoint: typing.Callable) -> str:
    """
    Return the default route name for `endpoint`.

    Functions, methods, builtins and classes use their own `__name__`. Any
    other callable (an instance, a `functools.partial`, ...) uses the name of
    its class.
    """
    is_named = inspect.isroutine(endpoint) or inspect.isclass(endpoint)
    owner = endpoint if is_named else endpoint.__class__
    return owner.__name__
def replace_params(
    path: str,
    param_convertors: typing.Dict[str, Convertor],
    path_params: typing.Dict[str, str],
) -> typing.Tuple[str, dict]:
    """
    Fill `{name}` placeholders in `path` with converted parameter values.

    For every key of `path_params` whose placeholder occurs in `path`, the
    value is rendered with that key's convertor (`to_string`), substituted for
    all occurrences, and the key is then removed from `path_params`.

    Returns:
        `(filled_path, path_params)`. The second element is the SAME dict that
        was passed in, mutated so it holds only the unused parameters.
    """
    for name in list(path_params):  # snapshot keys: entries are removed below
        placeholder = "{" + name + "}"
        if placeholder not in path:
            continue
        convertor = param_convertors[name]
        rendered = convertor.to_string(path_params[name])
        path = path.replace(placeholder, rendered)
        del path_params[name]
    return path, path_params
# Matches a path-parameter placeholder such as '{param}' or '{param:int}'.
# Group 1 is the parameter name; group 2 is the optional ':convertor' suffix
# (colon included), or None when no convertor is given.
PARAM_REGEX = re.compile(
    "{([a-zA-Z_][a-zA-Z0-9_]*)(:[a-zA-Z_][a-zA-Z0-9_]*)?}"
)
def compile_path(
    path: str,
) -> typing.Tuple[typing.Pattern, str, typing.Dict[str, Convertor]]:
    """
    Compile a route path or host pattern into (regex, format, convertors).

    Given a path string such as "/{username:str}", or a host string such as
    "{subdomain}.mydomain.org", return a three-tuple:

        regex:      "/(?P<username>[^/]+)"  (anchored with ^ ... $)
        format:     "/{username}"
        convertors: {"username": StringConvertor()}

    A pattern that does not start with "/" is treated as a host. For hosts,
    everything from the first ":" after the last placeholder (i.e. a port) is
    left out of the regex but kept in the format string.

    Raises:
        AssertionError: if a placeholder names an unknown convertor.
        ValueError: if the same parameter name appears more than once.
    """
    is_host = not path.startswith("/")

    regex_parts = ["^"]
    format_parts = []
    convertors: typing.Dict[str, Convertor] = {}
    repeated = set()
    cursor = 0

    for param in PARAM_REGEX.finditer(path):
        # An absent ":convertor" group defaults to "str".
        name, type_suffix = param.groups("str")
        type_name = type_suffix.lstrip(":")
        assert type_name in CONVERTOR_TYPES, f"Unknown path convertor '{type_name}'"
        convertor = CONVERTOR_TYPES[type_name]

        literal = path[cursor : param.start()]
        regex_parts.append(re.escape(literal))
        regex_parts.append(f"(?P<{name}>{convertor.regex})")
        format_parts.append(literal)
        format_parts.append("{%s}" % name)

        if name in convertors:
            repeated.add(name)
        convertors[name] = convertor
        cursor = param.end()

    if repeated:
        names = ", ".join(sorted(repeated))
        ending = "s" if len(repeated) > 1 else ""
        raise ValueError(f"Duplicated param name{ending} {names} at path {path}")

    tail = path[cursor:]
    if is_host:
        # Align with `Host.matches()` behavior, which ignores port.
        regex_parts.append(re.escape(tail.split(":")[0]))
    else:
        regex_parts.append(re.escape(tail))
    regex_parts.append("$")
    format_parts.append(tail)

    return re.compile("".join(regex_parts)), "".join(format_parts), convertors
class BaseRoute:
    """
    Common interface for routing entries (Route, WebSocketRoute, Mount, Host).

    Subclasses implement `matches`, `url_path_for` and `handle`. This base
    class only provides `__call__`, so that any route can be served directly
    as an ASGI application.
    """

    def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
        raise NotImplementedError()  # pragma: no cover

    def url_path_for(self, __name: str, **path_params: typing.Any) -> URLPath:
        raise NotImplementedError()  # pragma: no cover

    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
        raise NotImplementedError()  # pragma: no cover

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Serve this route on its own as an ASGI app.

        Routes are almost always used inside a Router; standalone use is a
        somewhat contrived case, but can be handy for tooling and minimal apps.
        On a match, the child scope is merged into `scope` and `handle` runs.
        Otherwise HTTP gets a plain-text 404, a websocket is closed, and any
        other scope type is ignored.
        """
        match, child_scope = self.matches(scope)
        if match != Match.NONE:
            scope.update(child_scope)
            await self.handle(scope, receive, send)
            return

        kind = scope["type"]
        if kind == "http":
            await PlainTextResponse("Not Found", status_code=404)(scope, receive, send)
        elif kind == "websocket":
            await WebSocketClose()(scope, receive, send)
class Route(BaseRoute):
    """
    An HTTP route: a path pattern bound to an endpoint.

    Plain functions and methods (also when wrapped in `functools.partial`) are
    adapted with `request_response`, i.e. called as `func(request) -> response`,
    and default to `methods=["GET"]`. Any other endpoint (e.g. a class) is used
    as a raw ASGI app and, when `methods` is None, accepts every method.
    Whenever "GET" is allowed, "HEAD" is allowed too.
    """

    def __init__(
        self,
        path: str,
        endpoint: typing.Callable,
        *,
        methods: typing.Optional[typing.List[str]] = None,
        name: typing.Optional[str] = None,
        include_in_schema: bool = True,
    ) -> None:
        assert path.startswith("/"), "Routed paths must start with '/'"
        self.path = path
        self.endpoint = endpoint
        self.name = get_name(endpoint) if name is None else name
        self.include_in_schema = include_in_schema

        # Look through functools.partial wrappers to decide how to adapt it.
        endpoint_handler = endpoint
        while isinstance(endpoint_handler, functools.partial):
            endpoint_handler = endpoint_handler.func
        if inspect.isfunction(endpoint_handler) or inspect.ismethod(endpoint_handler):
            # Endpoint is function or method. Treat it as `func(request) -> response`.
            self.app = request_response(endpoint)
            if methods is None:
                methods = ["GET"]
        else:
            # Anything else (e.g. a class). Treat it as ASGI.
            self.app = endpoint

        if methods is None:
            # No restriction: every method is accepted.
            self.methods = None
        else:
            self.methods = {method.upper() for method in methods}
            if "GET" in self.methods:
                self.methods.add("HEAD")

        self.path_regex, self.path_format, self.param_convertors = compile_path(path)

    def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
        """
        Match HTTP scopes whose path fits this route's pattern.

        Returns FULL when the method is allowed, or PARTIAL when the path
        matches but the method does not (the caller turns that into a 405).
        Converted path params are merged over any params already in the scope.
        """
        if scope["type"] == "http":
            match = self.path_regex.match(scope["path"])
            if match:
                matched_params = match.groupdict()
                for key, value in matched_params.items():
                    matched_params[key] = self.param_convertors[key].convert(value)
                path_params = dict(scope.get("path_params", {}))
                path_params.update(matched_params)
                child_scope = {"endpoint": self.endpoint, "path_params": path_params}
                if self.methods and scope["method"] not in self.methods:
                    return Match.PARTIAL, child_scope
                else:
                    return Match.FULL, child_scope
        return Match.NONE, {}

    def url_path_for(self, __name: str, **path_params: typing.Any) -> URLPath:
        """
        Build this route's URL path.

        Succeeds only when the name matches and exactly this route's params
        are supplied.

        Raises:
            NoMatchFound: on a name or parameter-set mismatch.
        """
        seen_params = set(path_params.keys())
        expected_params = set(self.param_convertors.keys())

        if __name != self.name or seen_params != expected_params:
            raise NoMatchFound(__name, path_params)

        path, remaining_params = replace_params(
            self.path_format, self.param_convertors, path_params
        )
        assert not remaining_params
        return URLPath(path=path, protocol="http")

    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Run the endpoint, or reply 405 Method Not Allowed.

        Inside a Starlette application (an "app" key is present in the scope)
        the 405 is raised as HTTPException so the app's exception handling can
        render it. Otherwise a plain-text response is sent directly.
        """
        if self.methods and scope["method"] not in self.methods:
            # Sort the methods: set iteration order of strings changes between
            # interpreter runs (hash randomization), which would make the Allow
            # header differ across processes. This also matches __repr__.
            headers = {"Allow": ", ".join(sorted(self.methods))}
            if "app" in scope:
                raise HTTPException(status_code=405, headers=headers)
            else:
                response = PlainTextResponse(
                    "Method Not Allowed", status_code=405, headers=headers
                )
            await response(scope, receive, send)
        else:
            await self.app(scope, receive, send)

    def __eq__(self, other: typing.Any) -> bool:
        return (
            isinstance(other, Route)
            and self.path == other.path
            and self.endpoint == other.endpoint
            and self.methods == other.methods
        )

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        methods = sorted(self.methods or [])
        path, name = self.path, self.name
        return f"{class_name}(path={path!r}, name={name!r}, methods={methods!r})"
class WebSocketRoute(BaseRoute):
    """
    A WebSocket route: a path pattern bound to a websocket endpoint.

    Plain functions and methods (possibly wrapped in `functools.partial`) are
    adapted with `websocket_session`, i.e. called as `func(websocket)`. Any
    other endpoint is used directly as an ASGI app.
    """

    def __init__(
        self, path: str, endpoint: typing.Callable, *, name: typing.Optional[str] = None
    ) -> None:
        assert path.startswith("/"), "Routed paths must start with '/'"
        self.path = path
        self.endpoint = endpoint
        self.name = name if name is not None else get_name(endpoint)

        target = endpoint
        while isinstance(target, functools.partial):
            target = target.func
        is_plain_callable = inspect.isfunction(target) or inspect.ismethod(target)
        self.app = websocket_session(endpoint) if is_plain_callable else endpoint

        self.path_regex, self.path_format, self.param_convertors = compile_path(path)

    def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
        """Fully match websocket scopes whose path fits this route's pattern."""
        if scope["type"] != "websocket":
            return Match.NONE, {}
        found = self.path_regex.match(scope["path"])
        if not found:
            return Match.NONE, {}
        converted = {
            key: self.param_convertors[key].convert(raw)
            for key, raw in found.groupdict().items()
        }
        # Newly matched params override any inherited from an outer scope.
        path_params = {**scope.get("path_params", {}), **converted}
        return Match.FULL, {"endpoint": self.endpoint, "path_params": path_params}

    def url_path_for(self, __name: str, **path_params: typing.Any) -> URLPath:
        """
        Build this route's URL path (websocket protocol).

        Raises:
            NoMatchFound: unless the name matches and exactly this route's
                params are supplied.
        """
        name_matches = __name == self.name
        params_match = set(path_params) == set(self.param_convertors)
        if not (name_matches and params_match):
            raise NoMatchFound(__name, path_params)

        path, leftover = replace_params(
            self.path_format, self.param_convertors, path_params
        )
        assert not leftover
        return URLPath(path=path, protocol="websocket")

    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
        await self.app(scope, receive, send)

    def __eq__(self, other: typing.Any) -> bool:
        if not isinstance(other, WebSocketRoute):
            return False
        return self.path == other.path and self.endpoint == other.endpoint

    def __repr__(self) -> str:
        return "%s(path=%r, name=%r)" % (type(self).__name__, self.path, self.name)
class Mount(BaseRoute):
    """
    Route everything under a path prefix to a sub-application.

    The sub-application is `app` if given, otherwise a `Router` built from
    `routes`. On a match, the matched prefix is moved from `path` onto
    `root_path` in the child scope, so the sub-application sees paths relative
    to the mount point.
    """

    def __init__(
        self,
        path: str,
        app: typing.Optional[ASGIApp] = None,
        routes: typing.Optional[typing.Sequence[BaseRoute]] = None,
        name: typing.Optional[str] = None,
        *,
        middleware: typing.Optional[typing.Sequence[Middleware]] = None,
    ) -> None:
        """
        Args:
            path: Mount prefix; must be "" or start with "/". Trailing slashes
                are stripped.
            app: ASGI app to mount. Used instead of `routes` when given.
            routes: Routes for an implicitly created `Router`, used when `app`
                is None.
            name: Optional name used by `url_path_for` ("<name>" or
                "<name>:<child_name>").
            middleware: Entries unpacked as `(cls, options)` and wrapped around
                the sub-application; the first entry ends up outermost.
        """
        assert path == "" or path.startswith("/"), "Routed paths must start with '/'"
        assert (
            app is not None or routes is not None
        ), "Either 'app=...', or 'routes=' must be specified"
        self.path = path.rstrip("/")
        if app is not None:
            self._base_app: ASGIApp = app
        else:
            self._base_app = Router(routes=routes)
        # `self.app` is what gets called; `_base_app` stays unwrapped so the
        # `routes` property can still reach the inner app's routes.
        self.app = self._base_app
        if middleware is not None:
            # Wrap in reverse so the first middleware listed is the outermost.
            for cls, options in reversed(middleware):
                self.app = cls(app=self.app, **options)
        self.name = name
        # The prefix followed by an arbitrary remainder, captured as "path".
        self.path_regex, self.path_format, self.param_convertors = compile_path(
            self.path + "/{path:path}"
        )

    @property
    def routes(self) -> typing.List[BaseRoute]:
        # The unwrapped sub-application's routes, or [] if it has none.
        return getattr(self._base_app, "routes", [])

    def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
        """
        Fully match http/websocket scopes whose path fits
        "<prefix>/{path:path}".

        The child scope gets the remainder as its `path`, the matched prefix
        appended to `root_path`, and the first-seen root path as
        `app_root_path`.
        """
        if scope["type"] in ("http", "websocket"):
            path = scope["path"]
            match = self.path_regex.match(path)
            if match:
                matched_params = match.groupdict()
                for key, value in matched_params.items():
                    matched_params[key] = self.param_convertors[key].convert(value)
                # The captured remainder becomes the child's path ...
                remaining_path = "/" + matched_params.pop("path")
                # ... and the consumed prefix is everything before it.
                # remaining_path is never empty, so this is never path[:-0].
                # (Assumes the "path" convertor returns the captured text
                # unchanged -- TODO confirm.)
                matched_path = path[: -len(remaining_path)]
                path_params = dict(scope.get("path_params", {}))
                path_params.update(matched_params)
                root_path = scope.get("root_path", "")
                child_scope = {
                    "path_params": path_params,
                    # Keep the outermost root path if one was already recorded.
                    "app_root_path": scope.get("app_root_path", root_path),
                    "root_path": root_path + matched_path,
                    "path": remaining_path,
                    "endpoint": self.app,
                }
                return Match.FULL, child_scope
        return Match.NONE, {}

    def url_path_for(self, __name: str, **path_params: typing.Any) -> URLPath:
        """
        Build a URL path for `__name`.

        - "<mount_name>" with a `path` param: the mount prefix plus `path`
          (leading slashes stripped). Every param must be consumed.
        - "<mount_name>:<child_name>", or any name when the mount is unnamed:
          fill in the prefix params, then ask each child route for the child
          name with the leftover params (plus `path`, if the caller passed
          one). The first success wins.

        Note: `replace_params` mutates `path_params` in place before the final
        raise, so the error message reflects the mutated dict.

        Raises:
            NoMatchFound: if neither case produces a URL.
        """
        if self.name is not None and __name == self.name and "path" in path_params:
            # 'name' matches "<mount_name>".
            path_params["path"] = path_params["path"].lstrip("/")
            path, remaining_params = replace_params(
                self.path_format, self.param_convertors, path_params
            )
            if not remaining_params:
                return URLPath(path=path)
        elif self.name is None or __name.startswith(self.name + ":"):
            if self.name is None:
                # No mount name.
                remaining_name = __name
            else:
                # 'name' matches "<mount_name>:<child_name>".
                remaining_name = __name[len(self.name) + 1 :]
            # Render only the prefix: fill {path} with "" here and pass any
            # caller-supplied `path` on to the child route instead.
            path_kwarg = path_params.get("path")
            path_params["path"] = ""
            path_prefix, remaining_params = replace_params(
                self.path_format, self.param_convertors, path_params
            )
            if path_kwarg is not None:
                remaining_params["path"] = path_kwarg
            for route in self.routes or []:
                try:
                    url = route.url_path_for(remaining_name, **remaining_params)
                    # The prefix ends with "/" and the child URL starts with
                    # one; strip to avoid a double slash.
                    return URLPath(
                        path=path_prefix.rstrip("/") + str(url), protocol=url.protocol
                    )
                except NoMatchFound:
                    pass
        raise NoMatchFound(__name, path_params)

    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
        await self.app(scope, receive, send)

    def __eq__(self, other: typing.Any) -> bool:
        return (
            isinstance(other, Mount)
            and self.path == other.path
            and self.app == other.app
        )

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        name = self.name or ""
        return f"{class_name}(path={self.path!r}, name={name!r}, app={self.app!r})"
class Host(BaseRoute):
    """
    Route requests to a sub-application by their Host header.

    `host` is a pattern such as "{subdomain}.example.org", compiled with
    `compile_path`. Any ":port" part of the incoming Host header is ignored
    when matching.
    """

    def __init__(
        self, host: str, app: ASGIApp, name: typing.Optional[str] = None
    ) -> None:
        assert not host.startswith("/"), "Host must not start with '/'"
        self.host = host
        self.app = app
        self.name = name
        compiled = compile_path(host)
        self.host_regex, self.host_format, self.param_convertors = compiled

    @property
    def routes(self) -> typing.List[BaseRoute]:
        # Routes of the wrapped app, or [] if it exposes none.
        return getattr(self.app, "routes", [])

    def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
        """Fully match http/websocket scopes whose Host header fits the pattern."""
        if scope["type"] not in ("http", "websocket"):
            return Match.NONE, {}
        hostname, _, _ = Headers(scope=scope).get("host", "").partition(":")
        found = self.host_regex.match(hostname)
        if not found:
            return Match.NONE, {}
        converted = {
            key: self.param_convertors[key].convert(raw)
            for key, raw in found.groupdict().items()
        }
        path_params = {**scope.get("path_params", {}), **converted}
        return Match.FULL, {"path_params": path_params, "endpoint": self.app}

    def url_path_for(self, __name: str, **path_params: typing.Any) -> URLPath:
        """
        Build a URL (path plus host) for `__name`.

        - "<host_name>" with a `path` param: use `path` verbatim and fill in
          the host pattern. Every remaining param must be consumed.
        - "<host_name>:<child_name>", or any name when unnamed: fill in the
          host, then ask each child route for the child name with the leftover
          params. The first success wins.

        Raises:
            NoMatchFound: if neither case produces a URL.
        """
        named = self.name is not None
        if named and __name == self.name and "path" in path_params:
            # 'name' matches "<host_name>": caller supplies the path itself.
            path = path_params.pop("path")
            host, leftover = replace_params(
                self.host_format, self.param_convertors, path_params
            )
            if not leftover:
                return URLPath(path=path, host=host)
        elif not named or __name.startswith(self.name + ":"):
            # Either no host name, or 'name' is "<host_name>:<child_name>".
            child_name = __name[len(self.name) + 1 :] if named else __name
            host, leftover = replace_params(
                self.host_format, self.param_convertors, path_params
            )
            for route in self.routes or []:
                try:
                    url = route.url_path_for(child_name, **leftover)
                except NoMatchFound:
                    continue
                return URLPath(path=str(url), protocol=url.protocol, host=host)
        raise NoMatchFound(__name, path_params)

    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
        await self.app(scope, receive, send)

    def __eq__(self, other: typing.Any) -> bool:
        if not isinstance(other, Host):
            return False
        return self.host == other.host and self.app == other.app

    def __repr__(self) -> str:
        shown_name = self.name or ""
        return "%s(host=%r, name=%r, app=%r)" % (
            type(self).__name__,
            self.host,
            shown_name,
            self.app,
        )
- _T = typing.TypeVar("_T")
- class _AsyncLiftContextManager(typing.AsyncContextManager[_T]):
- def __init__(self, cm: typing.ContextManager[_T]):
- self._cm = cm
- async def __aenter__(self) -> _T:
- return self._cm.__enter__()
- async def __aexit__(
- self,
- exc_type: typing.Optional[typing.Type[BaseException]],
- exc_value: typing.Optional[BaseException],
- traceback: typing.Optional[types.TracebackType],
- ) -> typing.Optional[bool]:
- return self._cm.__exit__(exc_type, exc_value, traceback)
def _wrap_gen_lifespan_context(
    lifespan_context: typing.Callable[[typing.Any], typing.Generator]
) -> typing.Callable[[typing.Any], typing.AsyncContextManager]:
    """
    Adapt a (deprecated) synchronous generator lifespan `fn(app)`.

    The result is a callable that returns an async context manager. The
    generator function is first turned into a regular context-manager factory
    via `contextlib.contextmanager`; each call's manager is then lifted with
    `_AsyncLiftContextManager`. `functools.wraps` copies the factory's
    metadata onto the returned wrapper.
    """
    sync_factory = contextlib.contextmanager(lifespan_context)

    def wrapper(app: typing.Any) -> _AsyncLiftContextManager:
        return _AsyncLiftContextManager(sync_factory(app))

    return functools.wraps(sync_factory)(wrapper)
class _DefaultLifespan:
    """
    Lifespan context used when the Router is given no `lifespan=`.

    Entering runs the router's `startup()` and exiting runs its `shutdown()`.
    Calling the instance with the app returns the instance itself, so it has
    the same `lifespan_context(app)` call shape as user-supplied lifespans.
    """

    def __init__(self, router: "Router"):
        self._router = router

    async def __aenter__(self) -> None:
        return await self._router.startup()

    async def __aexit__(self, *exc_info: object) -> None:
        # Exception details are ignored. Shutdown runs whether the body exited
        # normally or by raising, and returning None never suppresses.
        return await self._router.shutdown()

    def __call__(self: _T, app: object) -> _T:
        return self
class Router:
    """
    Dispatch ASGI connections to the first route that matches them.

    "lifespan" scopes are handled by `lifespan()`. For "http"/"websocket"
    scopes the routes are tried in order:

    1. the first route returning `Match.FULL` handles the request;
    2. otherwise the first `Match.PARTIAL` route does (this is how
       405 Method Not Allowed is produced);
    3. otherwise, for HTTP with `redirect_slashes` enabled and a path other
       than "/", a redirect is sent if adding/removing the trailing slash
       would match some route;
    4. otherwise `self.default` runs (`not_found` unless overridden).
    """

    def __init__(
        self,
        routes: typing.Optional[typing.Sequence[BaseRoute]] = None,
        redirect_slashes: bool = True,
        default: typing.Optional[ASGIApp] = None,
        on_startup: typing.Optional[typing.Sequence[typing.Callable]] = None,
        on_shutdown: typing.Optional[typing.Sequence[typing.Callable]] = None,
        # the generic to Lifespan[AppType] is the type of the top level application
        # which the router cannot know statically, so we use typing.Any
        lifespan: typing.Optional[Lifespan[typing.Any]] = None,
    ) -> None:
        """
        Args:
            routes: Routes tried in order; copied into a new list.
            redirect_slashes: Redirect HTTP requests that would match only
                after adding or removing a trailing slash.
            default: ASGI app used when nothing matches; defaults to
                `self.not_found`.
            on_startup, on_shutdown: Deprecated sync or async handlers, run by
                `startup()` / `shutdown()`.
            lifespan: A callable `fn(app)` returning an async context manager.
                Async and sync generator functions are still accepted but
                deprecated; they are wrapped into context managers here.
        """
        self.routes = [] if routes is None else list(routes)
        self.redirect_slashes = redirect_slashes
        self.default = self.not_found if default is None else default
        self.on_startup = [] if on_startup is None else list(on_startup)
        self.on_shutdown = [] if on_shutdown is None else list(on_shutdown)

        if on_startup or on_shutdown:
            warnings.warn(
                "The on_startup and on_shutdown parameters are deprecated, and they "
                "will be removed on version 1.0. Use the lifespan parameter instead. "
                "See more about it on https://www.starlette.io/lifespan/.",
                DeprecationWarning,
            )

        if lifespan is None:
            self.lifespan_context: Lifespan = _DefaultLifespan(self)

        elif inspect.isasyncgenfunction(lifespan):
            warnings.warn(
                "async generator function lifespans are deprecated, "
                "use an @contextlib.asynccontextmanager function instead",
                DeprecationWarning,
            )
            self.lifespan_context = asynccontextmanager(
                lifespan,  # type: ignore[arg-type]
            )
        elif inspect.isgeneratorfunction(lifespan):
            warnings.warn(
                "generator function lifespans are deprecated, "
                "use an @contextlib.asynccontextmanager function instead",
                DeprecationWarning,
            )
            self.lifespan_context = _wrap_gen_lifespan_context(
                lifespan,  # type: ignore[arg-type]
            )
        else:
            self.lifespan_context = lifespan

    async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Default fallback: close websockets, answer HTTP with 404."""
        if scope["type"] == "websocket":
            websocket_close = WebSocketClose()
            await websocket_close(scope, receive, send)
            return

        # If we're running inside a starlette application then raise an
        # exception, so that the configurable exception handler can deal with
        # returning the response. For plain ASGI apps, just return the response.
        if "app" in scope:
            raise HTTPException(status_code=404)
        else:
            response = PlainTextResponse("Not Found", status_code=404)
        await response(scope, receive, send)

    def url_path_for(self, __name: str, **path_params: typing.Any) -> URLPath:
        """
        Return the URL path from the first route that can build `__name`.

        Raises:
            NoMatchFound: if no route can.
        """
        for route in self.routes:
            try:
                return route.url_path_for(__name, **path_params)
            except NoMatchFound:
                pass
        raise NoMatchFound(__name, path_params)

    async def startup(self) -> None:
        """
        Run any `.on_startup` event handlers, in order; async ones are awaited.
        """
        for handler in self.on_startup:
            if is_async_callable(handler):
                await handler()
            else:
                handler()

    async def shutdown(self) -> None:
        """
        Run any `.on_shutdown` event handlers, in order; async ones are awaited.
        """
        for handler in self.on_shutdown:
            if is_async_callable(handler):
                await handler()
            else:
                handler()

    async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Handle ASGI lifespan messages, which allows us to manage application
        startup and shutdown events.

        Receives the first message (startup), enters `lifespan_context(app)`,
        reports startup complete, then waits for the next message (presumably
        shutdown) before leaving the context. State yielded by the context is
        merged into `scope["state"]`. On failure a startup/shutdown "failed"
        message carrying the traceback is sent and the exception is re-raised.
        """
        started = False
        app: typing.Any = scope.get("app")
        await receive()
        try:
            async with self.lifespan_context(app) as maybe_state:
                if maybe_state is not None:
                    if "state" not in scope:
                        raise RuntimeError(
                            'The server does not support "state" in the lifespan scope.'
                        )
                    scope["state"].update(maybe_state)
                await send({"type": "lifespan.startup.complete"})
                started = True
                await receive()
        except BaseException:
            exc_text = traceback.format_exc()
            if started:
                await send({"type": "lifespan.shutdown.failed", "message": exc_text})
            else:
                await send({"type": "lifespan.startup.failed", "message": exc_text})
            raise
        else:
            await send({"type": "lifespan.shutdown.complete"})

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        The main entry point to the Router class.
        """
        assert scope["type"] in ("http", "websocket", "lifespan")

        if "router" not in scope:
            scope["router"] = self

        if scope["type"] == "lifespan":
            await self.lifespan(scope, receive, send)
            return

        partial = None

        for route in self.routes:
            # Determine if any route matches the incoming scope,
            # and hand over to the matching route if found.
            match, child_scope = route.matches(scope)
            if match == Match.FULL:
                scope.update(child_scope)
                await route.handle(scope, receive, send)
                return
            elif match == Match.PARTIAL and partial is None:
                partial = route
                partial_scope = child_scope

        if partial is not None:
            #  Handle partial matches. These are cases where an endpoint is
            # able to handle the request, but is not a preferred option.
            # We use this in particular to deal with "405 Method Not Allowed".
            scope.update(partial_scope)
            await partial.handle(scope, receive, send)
            return

        if scope["type"] == "http" and self.redirect_slashes and scope["path"] != "/":
            # Try the same path with the trailing slash toggled.
            redirect_scope = dict(scope)
            if scope["path"].endswith("/"):
                redirect_scope["path"] = redirect_scope["path"].rstrip("/")
            else:
                redirect_scope["path"] = redirect_scope["path"] + "/"

            for route in self.routes:
                match, child_scope = route.matches(redirect_scope)
                if match != Match.NONE:
                    redirect_url = URL(scope=redirect_scope)
                    response = RedirectResponse(url=str(redirect_url))
                    await response(scope, receive, send)
                    return

        await self.default(scope, receive, send)

    def __eq__(self, other: typing.Any) -> bool:
        return isinstance(other, Router) and self.routes == other.routes

    def mount(
        self, path: str, app: ASGIApp, name: typing.Optional[str] = None
    ) -> None:  # pragma: nocover
        """Append a `Mount` of `app` at `path`."""
        route = Mount(path, app=app, name=name)
        self.routes.append(route)

    def host(
        self, host: str, app: ASGIApp, name: typing.Optional[str] = None
    ) -> None:  # pragma: no cover
        """Append a `Host` route dispatching `host` to `app`."""
        route = Host(host, app=app, name=name)
        self.routes.append(route)

    def add_route(
        self,
        path: str,
        endpoint: typing.Callable,
        methods: typing.Optional[typing.List[str]] = None,
        name: typing.Optional[str] = None,
        include_in_schema: bool = True,
    ) -> None:  # pragma: nocover
        """Append an HTTP `Route` for `endpoint` at `path`."""
        route = Route(
            path,
            endpoint=endpoint,
            methods=methods,
            name=name,
            include_in_schema=include_in_schema,
        )
        self.routes.append(route)

    def add_websocket_route(
        self, path: str, endpoint: typing.Callable, name: typing.Optional[str] = None
    ) -> None:  # pragma: no cover
        """Append a `WebSocketRoute` for `endpoint` at `path`."""
        route = WebSocketRoute(path, endpoint=endpoint, name=name)
        self.routes.append(route)

    def route(
        self,
        path: str,
        methods: typing.Optional[typing.List[str]] = None,
        name: typing.Optional[str] = None,
        include_in_schema: bool = True,
    ) -> typing.Callable:
        """
        We no longer document this decorator style API, and its usage is discouraged.
        Instead you should use the following approach:

        >>> routes = [Route(path, endpoint=...), ...]
        >>> app = Starlette(routes=routes)
        """
        warnings.warn(
            # Trailing space: the two literals are concatenated into one message.
            "The `route` decorator is deprecated, and will be removed in version 1.0.0. "
            "Refer to https://www.starlette.io/routing/#http-routing for the recommended approach.",  # noqa: E501
            DeprecationWarning,
        )

        def decorator(func: typing.Callable) -> typing.Callable:
            self.add_route(
                path,
                func,
                methods=methods,
                name=name,
                include_in_schema=include_in_schema,
            )
            return func

        return decorator

    def websocket_route(
        self, path: str, name: typing.Optional[str] = None
    ) -> typing.Callable:
        """
        We no longer document this decorator style API, and its usage is discouraged.
        Instead you should use the following approach:

        >>> routes = [WebSocketRoute(path, endpoint=...), ...]
        >>> app = Starlette(routes=routes)
        """
        warnings.warn(
            "The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. Refer to "  # noqa: E501
            "https://www.starlette.io/routing/#websocket-routing for the recommended approach.",  # noqa: E501
            DeprecationWarning,
        )

        def decorator(func: typing.Callable) -> typing.Callable:
            self.add_websocket_route(path, func, name=name)
            return func

        return decorator

    def add_event_handler(
        self, event_type: str, func: typing.Callable
    ) -> None:  # pragma: no cover
        """
        Register `func` as a "startup" or "shutdown" handler.

        Raises:
            AssertionError: for any other `event_type`.
        """
        assert event_type in ("startup", "shutdown")

        if event_type == "startup":
            self.on_startup.append(func)
        else:
            self.on_shutdown.append(func)

    def on_event(self, event_type: str) -> typing.Callable:
        """Deprecated decorator form of `add_event_handler`."""
        warnings.warn(
            "The `on_event` decorator is deprecated, and will be removed in version 1.0.0. "  # noqa: E501
            "Refer to https://www.starlette.io/lifespan/ for recommended approach.",
            DeprecationWarning,
        )

        def decorator(func: typing.Callable) -> typing.Callable:
            self.add_event_handler(event_type, func)
            return func

        return decorator
|