responses.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366
  1. import http.cookies
  2. import json
  3. import os
  4. import stat
  5. import sys
  6. import typing
  7. from datetime import datetime
  8. from email.utils import format_datetime, formatdate
  9. from functools import partial
  10. from mimetypes import guess_type as mimetypes_guess_type
  11. from urllib.parse import quote
  12. import anyio
  13. from starlette._compat import md5_hexdigest
  14. from starlette.background import BackgroundTask
  15. from starlette.concurrency import iterate_in_threadpool
  16. from starlette.datastructures import URL, MutableHeaders
  17. from starlette.types import Receive, Scope, Send
  18. if sys.version_info >= (3, 8): # pragma: no cover
  19. from typing import Literal
  20. else: # pragma: no cover
  21. from typing_extensions import Literal
  # ``http.cookies.Morsel`` only renders attributes registered in its
  # ``_reserved`` table, and Python versions before 3.8 do not list "samesite"
  # there. Registering it lets ``cookie[key]["samesite"] = ...`` be emitted as a
  # ``SameSite=...`` attribute on every supported version (presumably a no-op on
  # 3.8+, which added samesite support natively).
  http.cookies.Morsel._reserved.update(samesite="SameSite")  # type: ignore[attr-defined]
  # Compatibility shim around ``mimetypes.guess_type`` for path-like input.
  def guess_type(
      url: typing.Union[str, "os.PathLike[str]"], strict: bool = True
  ) -> typing.Tuple[typing.Optional[str], typing.Optional[str]]:
      """Return the ``(type, encoding)`` guess for *url*.

      ``mimetypes.guess_type`` only accepts ``os.PathLike`` objects from
      Python 3.8 onwards, so on older interpreters *url* is converted to a
      plain string first. *strict* is forwarded unchanged.
      """
      if sys.version_info >= (3, 8):  # pragma: no cover
          target = url
      else:  # pragma: no cover
          target = os.fspath(url)
      return mimetypes_guess_type(target, strict)
  class Response:
      """A complete, in-memory HTTP response for ASGI.

      The body is rendered once in ``__init__`` via :meth:`render`, and the raw
      ASGI header list is built by :meth:`init_headers`. Calling the instance
      sends ``http.response.start``, then a single ``http.response.body``
      message, and finally awaits the optional background task.
      """

      # Media type for the content-type header. Subclasses override it at class
      # level, and ``__init__`` may override it per instance.
      media_type: typing.Optional[str] = None
      # Encoding used for ``str`` bodies; also appended to ``text/*`` content types.
      charset = "utf-8"

      def __init__(
          self,
          content: typing.Any = None,
          status_code: int = 200,
          headers: typing.Optional[typing.Mapping[str, str]] = None,
          media_type: typing.Optional[str] = None,
          background: typing.Optional[BackgroundTask] = None,
      ) -> None:
          """Render *content* and build the response headers.

          Args:
              content: Body content, converted to bytes by :meth:`render`.
              status_code: HTTP status code to send.
              headers: Extra headers. Keys are lower-cased; keys and values
                  must be latin-1 encodable.
              media_type: Overrides the class-level ``media_type`` when given.
              background: Awaited after the body has been sent.
          """
          self.status_code = status_code
          if media_type is not None:
              # The instance attribute shadows the class-level default.
              self.media_type = media_type
          self.background = background
          self.body = self.render(content)
          self.init_headers(headers)

      def render(self, content: typing.Any) -> bytes:
          """Convert *content* into the bytes sent as the response body.

          ``None`` becomes an empty body and ``bytes`` pass through unchanged.
          Anything else is encoded with :attr:`charset`, so it must provide
          ``.encode`` (e.g. a ``str``).
          """
          if content is None:
              return b""
          if isinstance(content, bytes):
              return content
          return content.encode(self.charset)

      def init_headers(
          self, headers: typing.Optional[typing.Mapping[str, str]] = None
      ) -> None:
          """Build ``self.raw_headers`` from *headers* plus derived defaults.

          ``content-length`` (from ``self.body``) and ``content-type`` (from
          ``self.media_type``) are added only when *headers* does not already
          contain them. ``content-length`` is also skipped when the instance
          has no ``body`` attribute, and for 1xx, 204 and 304 responses.
          """
          if headers is None:
              raw_headers: typing.List[typing.Tuple[bytes, bytes]] = []
              populate_content_length = True
              populate_content_type = True
          else:
              raw_headers = [
                  (k.lower().encode("latin-1"), v.encode("latin-1"))
                  for k, v in headers.items()
              ]
              keys = [h[0] for h in raw_headers]
              populate_content_length = b"content-length" not in keys
              populate_content_type = b"content-type" not in keys

          # Responses that never set ``body`` (e.g. streamed ones) get no
          # content-length here.
          body = getattr(self, "body", None)
          if (
              body is not None
              and populate_content_length
              and not (self.status_code < 200 or self.status_code in (204, 304))
          ):
              content_length = str(len(body))
              raw_headers.append((b"content-length", content_length.encode("latin-1")))

          content_type = self.media_type
          if content_type is not None and populate_content_type:
              if content_type.startswith("text/"):
                  content_type += "; charset=" + self.charset
              raw_headers.append((b"content-type", content_type.encode("latin-1")))

          self.raw_headers = raw_headers

      @property
      def headers(self) -> MutableHeaders:
          """Lazily created mutable view over ``self.raw_headers``.

          Headers set through this property after ``init_headers`` are
          expected to appear in ``raw_headers`` when the response is sent, so
          the wrapper is presumed to share the underlying list. Confirm this
          against ``MutableHeaders``.
          """
          if not hasattr(self, "_headers"):
              self._headers = MutableHeaders(raw=self.raw_headers)
          return self._headers

      def set_cookie(
          self,
          key: str,
          value: str = "",
          max_age: typing.Optional[int] = None,
          expires: typing.Optional[typing.Union[datetime, str, int]] = None,
          path: str = "/",
          domain: typing.Optional[str] = None,
          secure: bool = False,
          httponly: bool = False,
          samesite: typing.Optional[Literal["lax", "strict", "none"]] = "lax",
      ) -> None:
          """Append a ``set-cookie`` header for *key*.

          Args:
              key: Cookie name.
              value: Cookie value.
              max_age: ``Max-Age`` attribute, omitted when ``None``.
              expires: ``datetime`` values are formatted as an HTTP date in GMT;
                  ``str``/``int`` values are passed to ``Morsel`` unchanged.
              path: ``Path`` attribute, omitted when ``None``.
              domain: ``Domain`` attribute, omitted when ``None``.
              secure: Adds the ``Secure`` flag when true.
              httponly: Adds the ``HttpOnly`` flag when true.
              samesite: ``SameSite`` attribute, checked case-insensitively
                  against strict/lax/none; omitted when ``None``.

          Raises:
              AssertionError: If *samesite* is not strict, lax or none.
          """
          cookie: "http.cookies.BaseCookie[str]" = http.cookies.SimpleCookie()
          cookie[key] = value
          if max_age is not None:
              cookie[key]["max-age"] = max_age
          if expires is not None:
              if isinstance(expires, datetime):
                  cookie[key]["expires"] = format_datetime(expires, usegmt=True)
              else:
                  cookie[key]["expires"] = expires
          if path is not None:
              cookie[key]["path"] = path
          if domain is not None:
              cookie[key]["domain"] = domain
          if secure:
              cookie[key]["secure"] = True
          if httponly:
              cookie[key]["httponly"] = True
          if samesite is not None:
              # An explicit check rather than ``assert``, so validation still runs
              # under ``python -O``. AssertionError is kept for existing callers.
              if samesite.lower() not in ("strict", "lax", "none"):
                  raise AssertionError(
                      "samesite must be either 'strict', 'lax' or 'none'"
                  )
              cookie[key]["samesite"] = samesite
          cookie_val = cookie.output(header="").strip()
          self.raw_headers.append((b"set-cookie", cookie_val.encode("latin-1")))

      def delete_cookie(
          self,
          key: str,
          path: str = "/",
          domain: typing.Optional[str] = None,
          secure: bool = False,
          httponly: bool = False,
          samesite: typing.Optional[Literal["lax", "strict", "none"]] = "lax",
      ) -> None:
          """Expire cookie *key* by re-setting it with ``max_age=0`` and ``expires=0``.

          *path*, *domain* and the flags should match the original cookie, so
          that the client treats it as the same cookie.
          """
          self.set_cookie(
              key,
              max_age=0,
              expires=0,
              path=path,
              domain=domain,
              secure=secure,
              httponly=httponly,
              samesite=samesite,
          )

      async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
          """Send the start and body messages, then await the background task."""
          await send(
              {
                  "type": "http.response.start",
                  "status": self.status_code,
                  "headers": self.raw_headers,
              }
          )
          await send({"type": "http.response.body", "body": self.body})
          if self.background is not None:
              await self.background()
  class HTMLResponse(Response):
      """Response served as ``text/html``.

      Because the media type starts with ``text/``, ``init_headers`` appends
      ``; charset=<charset>`` to the content-type header.
      """

      media_type = "text/html"
  class PlainTextResponse(Response):
      """Response served as ``text/plain``.

      Because the media type starts with ``text/``, ``init_headers`` appends
      ``; charset=<charset>`` to the content-type header.
      """

      media_type = "text/plain"
  class JSONResponse(Response):
      """Response that serialises *content* to compact JSON (``application/json``).

      Rendering uses ``ensure_ascii=False``, so non-ASCII characters are
      emitted as UTF-8 rather than escape sequences. It uses
      ``allow_nan=False``, so NaN and infinities raise ``ValueError`` instead of
      producing non-standard JSON. It also uses the most compact separators.
      """

      media_type = "application/json"

      def __init__(
          self,
          content: typing.Any,
          status_code: int = 200,
          # ``Mapping`` (not ``Dict``) to match ``Response.__init__``. Any dict
          # still satisfies it, so existing callers are unaffected.
          headers: typing.Optional[typing.Mapping[str, str]] = None,
          media_type: typing.Optional[str] = None,
          background: typing.Optional[BackgroundTask] = None,
      ) -> None:
          """Unlike ``Response``, *content* is required (``None`` renders as ``null``)."""
          super().__init__(content, status_code, headers, media_type, background)

      def render(self, content: typing.Any) -> bytes:
          """Serialise *content* with :func:`json.dumps` and encode it as UTF-8.

          Raises:
              ValueError: If *content* contains NaN or infinite floats.
          """
          return json.dumps(
              content,
              ensure_ascii=False,
              allow_nan=False,
              indent=None,
              separators=(",", ":"),
          ).encode("utf-8")
  class RedirectResponse(Response):
      """Empty-bodied response that points the client at *url* via ``location``.

      The default status is 307. The URL is percent-encoded with
      :func:`urllib.parse.quote`, but URL-structural characters and existing
      ``%`` escapes are left untouched.
      """

      def __init__(
          self,
          url: typing.Union[str, URL],
          status_code: int = 307,
          headers: typing.Optional[typing.Mapping[str, str]] = None,
          background: typing.Optional[BackgroundTask] = None,
      ) -> None:
          super().__init__(
              content=b"",
              status_code=status_code,
              headers=headers,
              background=background,
          )
          # Characters kept verbatim so an already well-formed URL survives quoting.
          keep_verbatim = ":/%#?=@[]!$&'()*+,;"
          target = str(url)
          self.headers["location"] = quote(target, safe=keep_verbatim)
  # One chunk of streamed body data; str chunks are encoded with the response charset.
  Content = typing.Union[str, bytes]
  # Synchronous chunk source.
  SyncContentStream = typing.Iterator[Content]
  # Asynchronous chunk source.
  AsyncContentStream = typing.AsyncIterable[Content]
  # Anything accepted as the streaming ``content`` argument: async first, then sync.
  ContentStream = typing.Union[AsyncContentStream, SyncContentStream]
  class StreamingResponse(Response):
      """Response whose body is produced incrementally from an iterator.

      An async iterable is consumed directly. A synchronous iterator is
      wrapped with ``iterate_in_threadpool``, which presumably advances it in
      a worker thread. ``str`` chunks are encoded with :attr:`charset`. No
      ``body`` attribute is set, so ``init_headers`` adds no content-length.
      """

      body_iterator: AsyncContentStream

      def __init__(
          self,
          content: ContentStream,
          status_code: int = 200,
          headers: typing.Optional[typing.Mapping[str, str]] = None,
          media_type: typing.Optional[str] = None,
          background: typing.Optional[BackgroundTask] = None,
      ) -> None:
          self.body_iterator = (
              content
              if isinstance(content, typing.AsyncIterable)
              else iterate_in_threadpool(content)
          )
          self.status_code = status_code
          # Always stored on the instance, falling back to the class default.
          self.media_type = media_type if media_type is not None else self.media_type
          self.background = background
          self.init_headers(headers)

      async def listen_for_disconnect(self, receive: Receive) -> None:
          """Consume incoming ASGI messages until an ``http.disconnect`` arrives."""
          while True:
              incoming = await receive()
              if incoming["type"] == "http.disconnect":
                  return

      async def stream_response(self, send: Send) -> None:
          """Send the start message, each chunk with ``more_body=True``, then a final empty chunk."""
          start_message = {
              "type": "http.response.start",
              "status": self.status_code,
              "headers": self.raw_headers,
          }
          await send(start_message)
          async for piece in self.body_iterator:
              payload = piece if isinstance(piece, bytes) else piece.encode(self.charset)
              await send({"type": "http.response.body", "body": payload, "more_body": True})
          await send({"type": "http.response.body", "body": b"", "more_body": False})

      async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
          """Stream the body while watching for disconnect; then run the background task."""
          async with anyio.create_task_group() as task_group:

              async def run_then_cancel(
                  func: "typing.Callable[[], typing.Awaitable[None]]",
              ) -> None:
                  # Whichever side finishes first (streaming done, or client
                  # disconnected) cancels the other one.
                  await func()
                  task_group.cancel_scope.cancel()

              streamer = partial(self.stream_response, send)
              watcher = partial(self.listen_for_disconnect, receive)
              task_group.start_soon(run_then_cancel, streamer)
              await run_then_cancel(watcher)

          if self.background is not None:
              await self.background()
  class FileResponse(Response):
      """Response that sends a file from disk in ``chunk_size`` pieces.

      How the headers are derived:

      - The media type is guessed from *filename* or *path*, falling back to
        ``text/plain``.
      - A ``content-disposition`` header is added when *filename* is given.
      - ``content-length``, ``last-modified`` and ``etag`` come from
        *stat_result* when supplied; otherwise they come from an ``os.stat``
        call made when the response is sent.

      With ``method="HEAD"`` the headers are sent with an empty body.
      """

      # Bytes read from the file for each ``http.response.body`` message.
      chunk_size = 64 * 1024

      def __init__(
          self,
          path: typing.Union[str, "os.PathLike[str]"],
          status_code: int = 200,
          headers: typing.Optional[typing.Mapping[str, str]] = None,
          media_type: typing.Optional[str] = None,
          background: typing.Optional[BackgroundTask] = None,
          filename: typing.Optional[str] = None,
          stat_result: typing.Optional[os.stat_result] = None,
          method: typing.Optional[str] = None,
          content_disposition_type: str = "attachment",
      ) -> None:
          """Prepare headers for sending the file at *path*.

          Args:
              path: File to send.
              status_code: HTTP status code.
              headers: Extra headers; caller values take precedence over derived ones.
              media_type: Explicit media type; guessed from the name when ``None``.
              background: Awaited after the body has been sent.
              filename: Name advertised in ``content-disposition``; the header
                  is omitted when ``None``.
              stat_result: Pre-computed ``os.stat`` result. When given, the
                  regular-file check in ``__call__`` is skipped.
              method: Request method; ``"HEAD"`` (any case) suppresses the body.
              content_disposition_type: Disposition, e.g. ``"attachment"``.
          """
          self.path = path
          self.status_code = status_code
          self.filename = filename
          self.send_header_only = method is not None and method.upper() == "HEAD"
          if media_type is None:
              media_type = guess_type(filename or path)[0] or "text/plain"
          self.media_type = media_type
          self.background = background
          # No ``body`` attribute exists, so this adds no content-length;
          # ``set_stat_headers`` supplies it from the file size instead.
          self.init_headers(headers)
          if self.filename is not None:
              content_disposition_filename = quote(self.filename)
              if content_disposition_filename != self.filename:
                  # Characters that need escaping: use the RFC 6266 extended
                  # ``filename*`` parameter with a percent-encoded UTF-8 value.
                  content_disposition = "{}; filename*=utf-8''{}".format(
                      content_disposition_type, content_disposition_filename
                  )
              else:
                  content_disposition = '{}; filename="{}"'.format(
                      content_disposition_type, self.filename
                  )
              self.headers.setdefault("content-disposition", content_disposition)
          self.stat_result = stat_result
          if stat_result is not None:
              self.set_stat_headers(stat_result)

      def set_stat_headers(self, stat_result: os.stat_result) -> None:
          """Add ``content-length``, ``last-modified`` and ``etag`` from *stat_result*.

          ``setdefault`` is used, so values the caller passed in ``headers``
          win. The ETag is the quoted MD5 hex digest of ``"<mtime>-<size>"``.
          """
          content_length = str(stat_result.st_size)
          last_modified = formatdate(stat_result.st_mtime, usegmt=True)
          etag_base = str(stat_result.st_mtime) + "-" + str(stat_result.st_size)
          # RFC 9110 entity-tags are double-quoted strings. An unquoted value is
          # not a valid ETag and may be ignored by clients and caches.
          etag = f'"{md5_hexdigest(etag_base.encode(), usedforsecurity=False)}"'
          self.headers.setdefault("content-length", content_length)
          self.headers.setdefault("last-modified", last_modified)
          self.headers.setdefault("etag", etag)

      async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
          """Stat the file if needed, then send headers and (unless HEAD) its contents.

          Raises:
              RuntimeError: If the file does not exist, or is not a regular
                  file (the latter check only runs when no ``stat_result`` was
                  supplied).
          """
          if self.stat_result is None:
              try:
                  stat_result = await anyio.to_thread.run_sync(os.stat, self.path)
              except FileNotFoundError as exc:
                  raise RuntimeError(f"File at path {self.path} does not exist.") from exc
              else:
                  self.set_stat_headers(stat_result)
                  mode = stat_result.st_mode
                  if not stat.S_ISREG(mode):
                      raise RuntimeError(f"File at path {self.path} is not a file.")
          await send(
              {
                  "type": "http.response.start",
                  "status": self.status_code,
                  "headers": self.raw_headers,
              }
          )
          if self.send_header_only:
              await send({"type": "http.response.body", "body": b"", "more_body": False})
          else:
              async with await anyio.open_file(self.path, mode="rb") as file:
                  more_body = True
                  while more_body:
                      chunk = await file.read(self.chunk_size)
                      # A short read signals EOF. If the size is an exact multiple
                      # of chunk_size, the last message is an empty chunk.
                      more_body = len(chunk) == self.chunk_size
                      await send(
                          {
                              "type": "http.response.body",
                              "body": chunk,
                              "more_body": more_body,
                          }
                      )
          if self.background is not None:
              await self.background()