| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366 |
- import http.cookies
- import json
- import os
- import stat
- import sys
- import typing
- from datetime import datetime
- from email.utils import format_datetime, formatdate
- from functools import partial
- from mimetypes import guess_type as mimetypes_guess_type
- from urllib.parse import quote
- import anyio
- from starlette._compat import md5_hexdigest
- from starlette.background import BackgroundTask
- from starlette.concurrency import iterate_in_threadpool
- from starlette.datastructures import URL, MutableHeaders
- from starlette.types import Receive, Scope, Send
- if sys.version_info >= (3, 8): # pragma: no cover
- from typing import Literal
- else: # pragma: no cover
- from typing_extensions import Literal
- # Workaround for adding samesite support to pre 3.8 python
- http.cookies.Morsel._reserved["samesite"] = "SameSite" # type: ignore[attr-defined]
- # Compatibility wrapper for `mimetypes.guess_type` to support `os.PathLike` on <py3.8
- def guess_type(
- url: typing.Union[str, "os.PathLike[str]"], strict: bool = True
- ) -> typing.Tuple[typing.Optional[str], typing.Optional[str]]:
- if sys.version_info < (3, 8): # pragma: no cover
- url = os.fspath(url)
- return mimetypes_guess_type(url, strict)
- class Response:
- media_type = None
- charset = "utf-8"
- def __init__(
- self,
- content: typing.Any = None,
- status_code: int = 200,
- headers: typing.Optional[typing.Mapping[str, str]] = None,
- media_type: typing.Optional[str] = None,
- background: typing.Optional[BackgroundTask] = None,
- ) -> None:
- self.status_code = status_code
- if media_type is not None:
- self.media_type = media_type
- self.background = background
- self.body = self.render(content)
- self.init_headers(headers)
- def render(self, content: typing.Any) -> bytes:
- if content is None:
- return b""
- if isinstance(content, bytes):
- return content
- return content.encode(self.charset)
- def init_headers(
- self, headers: typing.Optional[typing.Mapping[str, str]] = None
- ) -> None:
- if headers is None:
- raw_headers: typing.List[typing.Tuple[bytes, bytes]] = []
- populate_content_length = True
- populate_content_type = True
- else:
- raw_headers = [
- (k.lower().encode("latin-1"), v.encode("latin-1"))
- for k, v in headers.items()
- ]
- keys = [h[0] for h in raw_headers]
- populate_content_length = b"content-length" not in keys
- populate_content_type = b"content-type" not in keys
- body = getattr(self, "body", None)
- if (
- body is not None
- and populate_content_length
- and not (self.status_code < 200 or self.status_code in (204, 304))
- ):
- content_length = str(len(body))
- raw_headers.append((b"content-length", content_length.encode("latin-1")))
- content_type = self.media_type
- if content_type is not None and populate_content_type:
- if content_type.startswith("text/"):
- content_type += "; charset=" + self.charset
- raw_headers.append((b"content-type", content_type.encode("latin-1")))
- self.raw_headers = raw_headers
- @property
- def headers(self) -> MutableHeaders:
- if not hasattr(self, "_headers"):
- self._headers = MutableHeaders(raw=self.raw_headers)
- return self._headers
- def set_cookie(
- self,
- key: str,
- value: str = "",
- max_age: typing.Optional[int] = None,
- expires: typing.Optional[typing.Union[datetime, str, int]] = None,
- path: str = "/",
- domain: typing.Optional[str] = None,
- secure: bool = False,
- httponly: bool = False,
- samesite: typing.Optional[Literal["lax", "strict", "none"]] = "lax",
- ) -> None:
- cookie: "http.cookies.BaseCookie[str]" = http.cookies.SimpleCookie()
- cookie[key] = value
- if max_age is not None:
- cookie[key]["max-age"] = max_age
- if expires is not None:
- if isinstance(expires, datetime):
- cookie[key]["expires"] = format_datetime(expires, usegmt=True)
- else:
- cookie[key]["expires"] = expires
- if path is not None:
- cookie[key]["path"] = path
- if domain is not None:
- cookie[key]["domain"] = domain
- if secure:
- cookie[key]["secure"] = True
- if httponly:
- cookie[key]["httponly"] = True
- if samesite is not None:
- assert samesite.lower() in [
- "strict",
- "lax",
- "none",
- ], "samesite must be either 'strict', 'lax' or 'none'"
- cookie[key]["samesite"] = samesite
- cookie_val = cookie.output(header="").strip()
- self.raw_headers.append((b"set-cookie", cookie_val.encode("latin-1")))
- def delete_cookie(
- self,
- key: str,
- path: str = "/",
- domain: typing.Optional[str] = None,
- secure: bool = False,
- httponly: bool = False,
- samesite: typing.Optional[Literal["lax", "strict", "none"]] = "lax",
- ) -> None:
- self.set_cookie(
- key,
- max_age=0,
- expires=0,
- path=path,
- domain=domain,
- secure=secure,
- httponly=httponly,
- samesite=samesite,
- )
- async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
- await send(
- {
- "type": "http.response.start",
- "status": self.status_code,
- "headers": self.raw_headers,
- }
- )
- await send({"type": "http.response.body", "body": self.body})
- if self.background is not None:
- await self.background()
- class HTMLResponse(Response):
- media_type = "text/html"
- class PlainTextResponse(Response):
- media_type = "text/plain"
- class JSONResponse(Response):
- media_type = "application/json"
- def __init__(
- self,
- content: typing.Any,
- status_code: int = 200,
- headers: typing.Optional[typing.Dict[str, str]] = None,
- media_type: typing.Optional[str] = None,
- background: typing.Optional[BackgroundTask] = None,
- ) -> None:
- super().__init__(content, status_code, headers, media_type, background)
- def render(self, content: typing.Any) -> bytes:
- return json.dumps(
- content,
- ensure_ascii=False,
- allow_nan=False,
- indent=None,
- separators=(",", ":"),
- ).encode("utf-8")
- class RedirectResponse(Response):
- def __init__(
- self,
- url: typing.Union[str, URL],
- status_code: int = 307,
- headers: typing.Optional[typing.Mapping[str, str]] = None,
- background: typing.Optional[BackgroundTask] = None,
- ) -> None:
- super().__init__(
- content=b"", status_code=status_code, headers=headers, background=background
- )
- self.headers["location"] = quote(str(url), safe=":/%#?=@[]!$&'()*+,;")
- Content = typing.Union[str, bytes]
- SyncContentStream = typing.Iterator[Content]
- AsyncContentStream = typing.AsyncIterable[Content]
- ContentStream = typing.Union[AsyncContentStream, SyncContentStream]
- class StreamingResponse(Response):
- body_iterator: AsyncContentStream
- def __init__(
- self,
- content: ContentStream,
- status_code: int = 200,
- headers: typing.Optional[typing.Mapping[str, str]] = None,
- media_type: typing.Optional[str] = None,
- background: typing.Optional[BackgroundTask] = None,
- ) -> None:
- if isinstance(content, typing.AsyncIterable):
- self.body_iterator = content
- else:
- self.body_iterator = iterate_in_threadpool(content)
- self.status_code = status_code
- self.media_type = self.media_type if media_type is None else media_type
- self.background = background
- self.init_headers(headers)
- async def listen_for_disconnect(self, receive: Receive) -> None:
- while True:
- message = await receive()
- if message["type"] == "http.disconnect":
- break
- async def stream_response(self, send: Send) -> None:
- await send(
- {
- "type": "http.response.start",
- "status": self.status_code,
- "headers": self.raw_headers,
- }
- )
- async for chunk in self.body_iterator:
- if not isinstance(chunk, bytes):
- chunk = chunk.encode(self.charset)
- await send({"type": "http.response.body", "body": chunk, "more_body": True})
- await send({"type": "http.response.body", "body": b"", "more_body": False})
- async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
- async with anyio.create_task_group() as task_group:
- async def wrap(func: "typing.Callable[[], typing.Awaitable[None]]") -> None:
- await func()
- task_group.cancel_scope.cancel()
- task_group.start_soon(wrap, partial(self.stream_response, send))
- await wrap(partial(self.listen_for_disconnect, receive))
- if self.background is not None:
- await self.background()
- class FileResponse(Response):
- chunk_size = 64 * 1024
- def __init__(
- self,
- path: typing.Union[str, "os.PathLike[str]"],
- status_code: int = 200,
- headers: typing.Optional[typing.Mapping[str, str]] = None,
- media_type: typing.Optional[str] = None,
- background: typing.Optional[BackgroundTask] = None,
- filename: typing.Optional[str] = None,
- stat_result: typing.Optional[os.stat_result] = None,
- method: typing.Optional[str] = None,
- content_disposition_type: str = "attachment",
- ) -> None:
- self.path = path
- self.status_code = status_code
- self.filename = filename
- self.send_header_only = method is not None and method.upper() == "HEAD"
- if media_type is None:
- media_type = guess_type(filename or path)[0] or "text/plain"
- self.media_type = media_type
- self.background = background
- self.init_headers(headers)
- if self.filename is not None:
- content_disposition_filename = quote(self.filename)
- if content_disposition_filename != self.filename:
- content_disposition = "{}; filename*=utf-8''{}".format(
- content_disposition_type, content_disposition_filename
- )
- else:
- content_disposition = '{}; filename="{}"'.format(
- content_disposition_type, self.filename
- )
- self.headers.setdefault("content-disposition", content_disposition)
- self.stat_result = stat_result
- if stat_result is not None:
- self.set_stat_headers(stat_result)
- def set_stat_headers(self, stat_result: os.stat_result) -> None:
- content_length = str(stat_result.st_size)
- last_modified = formatdate(stat_result.st_mtime, usegmt=True)
- etag_base = str(stat_result.st_mtime) + "-" + str(stat_result.st_size)
- etag = md5_hexdigest(etag_base.encode(), usedforsecurity=False)
- self.headers.setdefault("content-length", content_length)
- self.headers.setdefault("last-modified", last_modified)
- self.headers.setdefault("etag", etag)
- async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
- if self.stat_result is None:
- try:
- stat_result = await anyio.to_thread.run_sync(os.stat, self.path)
- self.set_stat_headers(stat_result)
- except FileNotFoundError:
- raise RuntimeError(f"File at path {self.path} does not exist.")
- else:
- mode = stat_result.st_mode
- if not stat.S_ISREG(mode):
- raise RuntimeError(f"File at path {self.path} is not a file.")
- await send(
- {
- "type": "http.response.start",
- "status": self.status_code,
- "headers": self.raw_headers,
- }
- )
- if self.send_header_only:
- await send({"type": "http.response.body", "body": b"", "more_body": False})
- else:
- async with await anyio.open_file(self.path, mode="rb") as file:
- more_body = True
- while more_body:
- chunk = await file.read(self.chunk_size)
- more_body = len(chunk) == self.chunk_size
- await send(
- {
- "type": "http.response.body",
- "body": chunk,
- "more_body": more_body,
- }
- )
- if self.background is not None:
- await self.background()
|