requests.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. import json
  2. import typing
  3. from http import cookies as http_cookies
  4. import anyio
  5. from starlette._utils import AwaitableOrContextManager, AwaitableOrContextManagerWrapper
  6. from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State
  7. from starlette.exceptions import HTTPException
  8. from starlette.formparsers import FormParser, MultiPartException, MultiPartParser
  9. from starlette.types import Message, Receive, Scope, Send
  10. try:
  11. from multipart.multipart import parse_options_header
  12. except ModuleNotFoundError: # pragma: nocover
  13. parse_options_header = None
  14. if typing.TYPE_CHECKING:
  15. from starlette.routing import Router
  # Lower-case names of the request headers that `Request.send_push_promise`
  # copies from the current request into an "http.response.push" message.
  SERVER_PUSH_HEADERS_TO_COPY = {
      "accept",
      "accept-encoding",
      "accept-language",
      "cache-control",
      "user-agent",
  }
  def cookie_parser(cookie_string: str) -> typing.Dict[str, str]:
      """
      Parse the value of a ``Cookie`` HTTP header into a ``{name: value}`` dict.

      The parsing is deliberately lenient and tries to mimic what browsers do,
      because browsers and web servers frequently disregard the spec
      (RFC 6265) when setting and reading cookies.

      * The header is split on ``;``; each chunk is split on its first ``=``.
      * A chunk without ``=`` is stored as the value of a cookie whose name is
        the empty string.
      * Names and values are stripped of surrounding whitespace; chunks where
        both end up empty are skipped.
      * When a name occurs more than once, the last occurrence wins.

      This function has been adapted from Django 3.1.0.

      Note: `SimpleCookie.load` is explicitly _NOT_ used because it is based on
      an outdated spec and will fail on lots of input we want to support.
      """
      parsed: typing.Dict[str, str] = {}
      for piece in cookie_string.split(";"):
          name, separator, value = piece.partition("=")
          if not separator:
              # No "=" present: assume an empty name per
              # https://bugzilla.mozilla.org/show_bug.cgi?id=169091
              name, value = "", piece
          name = name.strip()
          value = value.strip()
          if not (name or value):
              continue
          # Unquote using Python's own cookie-value algorithm.
          parsed[name] = http_cookies._unquote(value)
      return parsed
  class ClientDisconnect(Exception):
      """Raised by `Request.stream` when an "http.disconnect" message is received."""
  class HTTPConnection(typing.Mapping[str, typing.Any]):
      """
      Base class for incoming connections, holding the functionality that is
      common to both `Request` and `WebSocket`.

      The instance is a read-only mapping view over the ASGI ``scope`` it was
      created with. Derived values (URL, headers, query parameters, cookies,
      state) are built on first access and cached on the instance.
      """

      def __init__(self, scope: Scope, receive: typing.Optional[Receive] = None) -> None:
          # `receive` is accepted for signature compatibility; it is not used here.
          assert scope["type"] in ("http", "websocket")
          self.scope = scope

      # --- Mapping protocol: delegate straight to the scope -----------------

      def __getitem__(self, key: str) -> typing.Any:
          return self.scope[key]

      def __iter__(self) -> typing.Iterator[str]:
          return iter(self.scope)

      def __len__(self) -> int:
          return len(self.scope)

      # Don't use the `abc.Mapping.__eq__` implementation.
      # Connection instances should never be considered equal
      # unless `self is other`.
      __eq__ = object.__eq__
      __hash__ = object.__hash__

      # --- Values derived from the scope ------------------------------------

      @property
      def app(self) -> typing.Any:
          """The value stored under ``scope["app"]``."""
          return self.scope["app"]

      @property
      def url(self) -> URL:
          """The `URL` built from the scope (cached after first access)."""
          if hasattr(self, "_url"):
              return self._url
          self._url = URL(scope=self.scope)
          return self._url

      @property
      def base_url(self) -> URL:
          """
          The `URL` of the application root (cached after first access).

          Built from a copy of the scope whose path is ``"/"``, whose query
          string is empty, and whose root path is ``app_root_path`` when
          present, falling back to ``root_path`` and then ``""``.
          """
          if hasattr(self, "_base_url"):
              return self._base_url
          scope = self.scope
          root_scope = {
              **scope,
              "path": "/",
              "query_string": b"",
              "root_path": scope.get("app_root_path", scope.get("root_path", "")),
          }
          self._base_url = URL(scope=root_scope)
          return self._base_url

      @property
      def headers(self) -> Headers:
          """The `Headers` built from the scope (cached after first access)."""
          if hasattr(self, "_headers"):
              return self._headers
          self._headers = Headers(scope=self.scope)
          return self._headers

      @property
      def query_params(self) -> QueryParams:
          """`QueryParams` built from ``scope["query_string"]`` (cached)."""
          if hasattr(self, "_query_params"):
              return self._query_params
          self._query_params = QueryParams(self.scope["query_string"])
          return self._query_params

      @property
      def path_params(self) -> typing.Dict[str, typing.Any]:
          """``scope["path_params"]``, or a new empty dict when the key is absent."""
          return self.scope.get("path_params", {})

      @property
      def cookies(self) -> typing.Dict[str, str]:
          """
          Cookies parsed from the ``cookie`` header (cached after first access).

          An absent or empty header yields an empty dict.
          """
          if hasattr(self, "_cookies"):
              return self._cookies
          raw_cookie = self.headers.get("cookie")
          self._cookies = cookie_parser(raw_cookie) if raw_cookie else {}
          return self._cookies

      @property
      def client(self) -> typing.Optional[Address]:
          """The client `Address`, or ``None`` when the scope has no client."""
          # client is a 2 item tuple of (host, port), None or missing
          endpoint = self.scope.get("client")
          return None if endpoint is None else Address(*endpoint)

      # --- Values populated by middleware -----------------------------------

      @property
      def session(self) -> typing.Dict[str, typing.Any]:
          """``scope["session"]``; asserts that the key is present."""
          has_session = "session" in self.scope
          assert has_session, "SessionMiddleware must be installed to access request.session"
          return self.scope["session"]

      @property
      def auth(self) -> typing.Any:
          """``scope["auth"]``; asserts that the key is present."""
          has_auth = "auth" in self.scope
          assert has_auth, "AuthenticationMiddleware must be installed to access request.auth"
          return self.scope["auth"]

      @property
      def user(self) -> typing.Any:
          """``scope["user"]``; asserts that the key is present."""
          has_user = "user" in self.scope
          assert has_user, "AuthenticationMiddleware must be installed to access request.user"
          return self.scope["user"]

      @property
      def state(self) -> State:
          """
          A `State` wrapper around ``scope["state"]`` (cached after first access).

          The scope entry is created as an empty dict when it does not exist yet,
          and the `State` instance is handed that same dict to store its data in.
          """
          if hasattr(self, "_state"):
              return self._state
          backing_store = self.scope.setdefault("state", {})
          self._state = State(backing_store)
          return self._state

      def url_for(self, __name: str, **path_params: typing.Any) -> URL:
          """
          Build the absolute URL of the route called ``__name``.

          The path is resolved by ``scope["router"]`` with ``path_params`` and
          then made absolute against `base_url`.
          """
          router: Router = self.scope["router"]
          route_path = router.url_path_for(__name, **path_params)
          return route_path.make_absolute_url(base_url=self.base_url)
  async def empty_receive() -> typing.NoReturn:
      """Default ``receive`` callable for `Request`: always raises `RuntimeError`."""
      reason = "Receive channel has not been made available"
      raise RuntimeError(reason)
  async def empty_send(message: Message) -> typing.NoReturn:
      """Default ``send`` callable for `Request`: always raises `RuntimeError`."""
      reason = "Send channel has not been made available"
      raise RuntimeError(reason)
  class Request(HTTPConnection):
      """
      An incoming HTTP request.

      Adds body access (`stream`, `body`, `json`, `form`), disconnect detection
      and server push on top of `HTTPConnection`. The scope type must be
      ``"http"``.
      """

      _form: typing.Optional[FormData]

      def __init__(
          self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send
      ):
          super().__init__(scope)
          assert scope["type"] == "http"
          self._receive = receive
          self._send = send
          self._stream_consumed = False
          self._is_disconnected = False
          self._form = None

      @property
      def method(self) -> str:
          """The value stored under ``scope["method"]``."""
          return self.scope["method"]

      @property
      def receive(self) -> Receive:
          """The ``receive`` callable this request was created with."""
          return self._receive

      async def stream(self) -> typing.AsyncGenerator[bytes, None]:
          """
          Yield the request body chunk by chunk, ending with an empty ``b""``.

          If the body has already been read via `body`, the cached bytes are
          yielded instead of reading from the receive channel.

          Raises:
              RuntimeError: if the stream was already consumed and no cached
                  body is available.
              ClientDisconnect: if an "http.disconnect" message arrives before
                  the body is complete.
          """
          if hasattr(self, "_body"):
              yield self._body
              yield b""
              return

          if self._stream_consumed:
              raise RuntimeError("Stream consumed")
          self._stream_consumed = True

          more_body = True
          while more_body:
              event = await self._receive()
              if event["type"] == "http.request":
                  payload = event.get("body", b"")
                  if payload:
                      yield payload
                  more_body = event.get("more_body", False)
              elif event["type"] == "http.disconnect":
                  self._is_disconnected = True
                  raise ClientDisconnect()
              # Any other message type is ignored and the loop keeps waiting.
          yield b""

      async def body(self) -> bytes:
          """Return the whole request body as bytes (cached after the first call)."""
          if hasattr(self, "_body"):
              return self._body
          pieces: "typing.List[bytes]" = [piece async for piece in self.stream()]
          self._body = b"".join(pieces)
          return self._body

      async def json(self) -> typing.Any:
          """Return the request body decoded with `json.loads` (cached)."""
          if hasattr(self, "_json"):
              return self._json
          raw = await self.body()
          self._json = json.loads(raw)
          return self._json

      async def _get_form(
          self,
          *,
          max_files: typing.Union[int, float] = 1000,
          max_fields: typing.Union[int, float] = 1000,
      ) -> FormData:
          """
          Parse the request body as form data and cache the result.

          ``multipart/form-data`` bodies go through `MultiPartParser` (with the
          given ``max_files`` / ``max_fields`` limits) and
          ``application/x-www-form-urlencoded`` bodies through `FormParser`.
          Any other content type yields an empty `FormData`.

          Raises:
              HTTPException: status 400 when multipart parsing fails and the
                  scope contains an ``"app"`` key.
              MultiPartException: when multipart parsing fails and the scope has
                  no ``"app"`` key.
          """
          if self._form is not None:
              return self._form

          parser_available = parse_options_header is not None
          assert (
              parser_available
          ), "The `python-multipart` library must be installed to use form parsing."

          media_type: bytes
          media_type, _ = parse_options_header(self.headers.get("Content-Type"))

          if media_type == b"multipart/form-data":
              try:
                  multipart = MultiPartParser(
                      self.headers,
                      self.stream(),
                      max_files=max_files,
                      max_fields=max_fields,
                  )
                  self._form = await multipart.parse()
              except MultiPartException as exc:
                  # Inside an application, report the failure as a 400 response;
                  # otherwise let the parser error propagate unchanged.
                  if "app" in self.scope:
                      raise HTTPException(status_code=400, detail=exc.message)
                  raise exc
          elif media_type == b"application/x-www-form-urlencoded":
              urlencoded = FormParser(self.headers, self.stream())
              self._form = await urlencoded.parse()
          else:
              self._form = FormData()
          return self._form

      def form(
          self,
          *,
          max_files: typing.Union[int, float] = 1000,
          max_fields: typing.Union[int, float] = 1000,
      ) -> AwaitableOrContextManager[FormData]:
          """
          Return the parsed form data, wrapped in
          `AwaitableOrContextManagerWrapper`.
          """
          pending_form = self._get_form(max_files=max_files, max_fields=max_fields)
          return AwaitableOrContextManagerWrapper(pending_form)

      async def close(self) -> None:
          """Close the parsed form data, if any form has been parsed."""
          if self._form is None:
              return
          await self._form.close()

      async def is_disconnected(self) -> bool:
          """
          Report whether an "http.disconnect" message has been seen.

          Once a disconnect has been recorded, no further receive is attempted.
          """
          if self._is_disconnected:
              return self._is_disconnected

          event: Message = {}
          # If message isn't immediately available, move on
          with anyio.CancelScope() as scope:
              scope.cancel()
              event = await self._receive()

          if event.get("type") == "http.disconnect":
              self._is_disconnected = True
          return self._is_disconnected

      async def send_push_promise(self, path: str) -> None:
          """
          Send an "http.response.push" message for ``path``.

          Does nothing unless the scope's ``extensions`` advertise
          ``"http.response.push"``. The headers named in
          `SERVER_PUSH_HEADERS_TO_COPY` are copied from this request,
          latin-1 encoded.
          """
          if "http.response.push" not in self.scope.get("extensions", {}):
              return

          raw_headers: "typing.List[typing.Tuple[bytes, bytes]]" = [
              (name.encode("latin-1"), value.encode("latin-1"))
              for name in SERVER_PUSH_HEADERS_TO_COPY
              for value in self.headers.getlist(name)
          ]
          await self._send(
              {"type": "http.response.push", "path": path, "headers": raw_headers}
          )