# datastructures.py
  1. import typing
  2. from collections.abc import Sequence
  3. from shlex import shlex
  4. from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit
  5. from starlette.concurrency import run_in_threadpool
  6. from starlette.types import Scope
  7. class Address(typing.NamedTuple):
  8. host: str
  9. port: int
  10. _KeyType = typing.TypeVar("_KeyType")
  11. # Mapping keys are invariant but their values are covariant since
  12. # you can only read them
  13. # that is, you can't do `Mapping[str, Animal]()["fido"] = Dog()`
  14. _CovariantValueType = typing.TypeVar("_CovariantValueType", covariant=True)
  15. class URL:
  16. def __init__(
  17. self,
  18. url: str = "",
  19. scope: typing.Optional[Scope] = None,
  20. **components: typing.Any,
  21. ) -> None:
  22. if scope is not None:
  23. assert not url, 'Cannot set both "url" and "scope".'
  24. assert not components, 'Cannot set both "scope" and "**components".'
  25. scheme = scope.get("scheme", "http")
  26. server = scope.get("server", None)
  27. path = scope.get("root_path", "") + scope["path"]
  28. query_string = scope.get("query_string", b"")
  29. host_header = None
  30. for key, value in scope["headers"]:
  31. if key == b"host":
  32. host_header = value.decode("latin-1")
  33. break
  34. if host_header is not None:
  35. url = f"{scheme}://{host_header}{path}"
  36. elif server is None:
  37. url = path
  38. else:
  39. host, port = server
  40. default_port = {"http": 80, "https": 443, "ws": 80, "wss": 443}[scheme]
  41. if port == default_port:
  42. url = f"{scheme}://{host}{path}"
  43. else:
  44. url = f"{scheme}://{host}:{port}{path}"
  45. if query_string:
  46. url += "?" + query_string.decode()
  47. elif components:
  48. assert not url, 'Cannot set both "url" and "**components".'
  49. url = URL("").replace(**components).components.geturl()
  50. self._url = url
  51. @property
  52. def components(self) -> SplitResult:
  53. if not hasattr(self, "_components"):
  54. self._components = urlsplit(self._url)
  55. return self._components
  56. @property
  57. def scheme(self) -> str:
  58. return self.components.scheme
  59. @property
  60. def netloc(self) -> str:
  61. return self.components.netloc
  62. @property
  63. def path(self) -> str:
  64. return self.components.path
  65. @property
  66. def query(self) -> str:
  67. return self.components.query
  68. @property
  69. def fragment(self) -> str:
  70. return self.components.fragment
  71. @property
  72. def username(self) -> typing.Union[None, str]:
  73. return self.components.username
  74. @property
  75. def password(self) -> typing.Union[None, str]:
  76. return self.components.password
  77. @property
  78. def hostname(self) -> typing.Union[None, str]:
  79. return self.components.hostname
  80. @property
  81. def port(self) -> typing.Optional[int]:
  82. return self.components.port
  83. @property
  84. def is_secure(self) -> bool:
  85. return self.scheme in ("https", "wss")
  86. def replace(self, **kwargs: typing.Any) -> "URL":
  87. if (
  88. "username" in kwargs
  89. or "password" in kwargs
  90. or "hostname" in kwargs
  91. or "port" in kwargs
  92. ):
  93. hostname = kwargs.pop("hostname", None)
  94. port = kwargs.pop("port", self.port)
  95. username = kwargs.pop("username", self.username)
  96. password = kwargs.pop("password", self.password)
  97. if hostname is None:
  98. netloc = self.netloc
  99. _, _, hostname = netloc.rpartition("@")
  100. if hostname[-1] != "]":
  101. hostname = hostname.rsplit(":", 1)[0]
  102. netloc = hostname
  103. if port is not None:
  104. netloc += f":{port}"
  105. if username is not None:
  106. userpass = username
  107. if password is not None:
  108. userpass += f":{password}"
  109. netloc = f"{userpass}@{netloc}"
  110. kwargs["netloc"] = netloc
  111. components = self.components._replace(**kwargs)
  112. return self.__class__(components.geturl())
  113. def include_query_params(self, **kwargs: typing.Any) -> "URL":
  114. params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
  115. params.update({str(key): str(value) for key, value in kwargs.items()})
  116. query = urlencode(params.multi_items())
  117. return self.replace(query=query)
  118. def replace_query_params(self, **kwargs: typing.Any) -> "URL":
  119. query = urlencode([(str(key), str(value)) for key, value in kwargs.items()])
  120. return self.replace(query=query)
  121. def remove_query_params(
  122. self, keys: typing.Union[str, typing.Sequence[str]]
  123. ) -> "URL":
  124. if isinstance(keys, str):
  125. keys = [keys]
  126. params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
  127. for key in keys:
  128. params.pop(key, None)
  129. query = urlencode(params.multi_items())
  130. return self.replace(query=query)
  131. def __eq__(self, other: typing.Any) -> bool:
  132. return str(self) == str(other)
  133. def __str__(self) -> str:
  134. return self._url
  135. def __repr__(self) -> str:
  136. url = str(self)
  137. if self.password:
  138. url = str(self.replace(password="********"))
  139. return f"{self.__class__.__name__}({repr(url)})"
  140. class URLPath(str):
  141. """
  142. A URL path string that may also hold an associated protocol and/or host.
  143. Used by the routing to return `url_path_for` matches.
  144. """
  145. def __new__(cls, path: str, protocol: str = "", host: str = "") -> "URLPath":
  146. assert protocol in ("http", "websocket", "")
  147. return str.__new__(cls, path)
  148. def __init__(self, path: str, protocol: str = "", host: str = "") -> None:
  149. self.protocol = protocol
  150. self.host = host
  151. def make_absolute_url(self, base_url: typing.Union[str, URL]) -> URL:
  152. if isinstance(base_url, str):
  153. base_url = URL(base_url)
  154. if self.protocol:
  155. scheme = {
  156. "http": {True: "https", False: "http"},
  157. "websocket": {True: "wss", False: "ws"},
  158. }[self.protocol][base_url.is_secure]
  159. else:
  160. scheme = base_url.scheme
  161. netloc = self.host or base_url.netloc
  162. path = base_url.path.rstrip("/") + str(self)
  163. return URL(scheme=scheme, netloc=netloc, path=path)
  164. class Secret:
  165. """
  166. Holds a string value that should not be revealed in tracebacks etc.
  167. You should cast the value to `str` at the point it is required.
  168. """
  169. def __init__(self, value: str):
  170. self._value = value
  171. def __repr__(self) -> str:
  172. class_name = self.__class__.__name__
  173. return f"{class_name}('**********')"
  174. def __str__(self) -> str:
  175. return self._value
  176. def __bool__(self) -> bool:
  177. return bool(self._value)
  178. class CommaSeparatedStrings(Sequence):
  179. def __init__(self, value: typing.Union[str, typing.Sequence[str]]):
  180. if isinstance(value, str):
  181. splitter = shlex(value, posix=True)
  182. splitter.whitespace = ","
  183. splitter.whitespace_split = True
  184. self._items = [item.strip() for item in splitter]
  185. else:
  186. self._items = list(value)
  187. def __len__(self) -> int:
  188. return len(self._items)
  189. def __getitem__(self, index: typing.Union[int, slice]) -> typing.Any:
  190. return self._items[index]
  191. def __iter__(self) -> typing.Iterator[str]:
  192. return iter(self._items)
  193. def __repr__(self) -> str:
  194. class_name = self.__class__.__name__
  195. items = [item for item in self]
  196. return f"{class_name}({items!r})"
  197. def __str__(self) -> str:
  198. return ", ".join(repr(item) for item in self)
  199. class ImmutableMultiDict(typing.Mapping[_KeyType, _CovariantValueType]):
  200. _dict: typing.Dict[_KeyType, _CovariantValueType]
  201. def __init__(
  202. self,
  203. *args: typing.Union[
  204. "ImmutableMultiDict[_KeyType, _CovariantValueType]",
  205. typing.Mapping[_KeyType, _CovariantValueType],
  206. typing.Iterable[typing.Tuple[_KeyType, _CovariantValueType]],
  207. ],
  208. **kwargs: typing.Any,
  209. ) -> None:
  210. assert len(args) < 2, "Too many arguments."
  211. value: typing.Any = args[0] if args else []
  212. if kwargs:
  213. value = (
  214. ImmutableMultiDict(value).multi_items()
  215. + ImmutableMultiDict(kwargs).multi_items() # type: ignore[operator]
  216. )
  217. if not value:
  218. _items: typing.List[typing.Tuple[typing.Any, typing.Any]] = []
  219. elif hasattr(value, "multi_items"):
  220. value = typing.cast(
  221. ImmutableMultiDict[_KeyType, _CovariantValueType], value
  222. )
  223. _items = list(value.multi_items())
  224. elif hasattr(value, "items"):
  225. value = typing.cast(typing.Mapping[_KeyType, _CovariantValueType], value)
  226. _items = list(value.items())
  227. else:
  228. value = typing.cast(
  229. typing.List[typing.Tuple[typing.Any, typing.Any]], value
  230. )
  231. _items = list(value)
  232. self._dict = {k: v for k, v in _items}
  233. self._list = _items
  234. def getlist(self, key: typing.Any) -> typing.List[_CovariantValueType]:
  235. return [item_value for item_key, item_value in self._list if item_key == key]
  236. def keys(self) -> typing.KeysView[_KeyType]:
  237. return self._dict.keys()
  238. def values(self) -> typing.ValuesView[_CovariantValueType]:
  239. return self._dict.values()
  240. def items(self) -> typing.ItemsView[_KeyType, _CovariantValueType]:
  241. return self._dict.items()
  242. def multi_items(self) -> typing.List[typing.Tuple[_KeyType, _CovariantValueType]]:
  243. return list(self._list)
  244. def __getitem__(self, key: _KeyType) -> _CovariantValueType:
  245. return self._dict[key]
  246. def __contains__(self, key: typing.Any) -> bool:
  247. return key in self._dict
  248. def __iter__(self) -> typing.Iterator[_KeyType]:
  249. return iter(self.keys())
  250. def __len__(self) -> int:
  251. return len(self._dict)
  252. def __eq__(self, other: typing.Any) -> bool:
  253. if not isinstance(other, self.__class__):
  254. return False
  255. return sorted(self._list) == sorted(other._list)
  256. def __repr__(self) -> str:
  257. class_name = self.__class__.__name__
  258. items = self.multi_items()
  259. return f"{class_name}({items!r})"
  260. class MultiDict(ImmutableMultiDict[typing.Any, typing.Any]):
  261. def __setitem__(self, key: typing.Any, value: typing.Any) -> None:
  262. self.setlist(key, [value])
  263. def __delitem__(self, key: typing.Any) -> None:
  264. self._list = [(k, v) for k, v in self._list if k != key]
  265. del self._dict[key]
  266. def pop(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
  267. self._list = [(k, v) for k, v in self._list if k != key]
  268. return self._dict.pop(key, default)
  269. def popitem(self) -> typing.Tuple:
  270. key, value = self._dict.popitem()
  271. self._list = [(k, v) for k, v in self._list if k != key]
  272. return key, value
  273. def poplist(self, key: typing.Any) -> typing.List:
  274. values = [v for k, v in self._list if k == key]
  275. self.pop(key)
  276. return values
  277. def clear(self) -> None:
  278. self._dict.clear()
  279. self._list.clear()
  280. def setdefault(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
  281. if key not in self:
  282. self._dict[key] = default
  283. self._list.append((key, default))
  284. return self[key]
  285. def setlist(self, key: typing.Any, values: typing.List) -> None:
  286. if not values:
  287. self.pop(key, None)
  288. else:
  289. existing_items = [(k, v) for (k, v) in self._list if k != key]
  290. self._list = existing_items + [(key, value) for value in values]
  291. self._dict[key] = values[-1]
  292. def append(self, key: typing.Any, value: typing.Any) -> None:
  293. self._list.append((key, value))
  294. self._dict[key] = value
  295. def update(
  296. self,
  297. *args: typing.Union[
  298. "MultiDict",
  299. typing.Mapping,
  300. typing.List[typing.Tuple[typing.Any, typing.Any]],
  301. ],
  302. **kwargs: typing.Any,
  303. ) -> None:
  304. value = MultiDict(*args, **kwargs)
  305. existing_items = [(k, v) for (k, v) in self._list if k not in value.keys()]
  306. self._list = existing_items + value.multi_items()
  307. self._dict.update(value)
  308. class QueryParams(ImmutableMultiDict[str, str]):
  309. """
  310. An immutable multidict.
  311. """
  312. def __init__(
  313. self,
  314. *args: typing.Union[
  315. "ImmutableMultiDict",
  316. typing.Mapping,
  317. typing.List[typing.Tuple[typing.Any, typing.Any]],
  318. str,
  319. bytes,
  320. ],
  321. **kwargs: typing.Any,
  322. ) -> None:
  323. assert len(args) < 2, "Too many arguments."
  324. value = args[0] if args else []
  325. if isinstance(value, str):
  326. super().__init__(parse_qsl(value, keep_blank_values=True), **kwargs)
  327. elif isinstance(value, bytes):
  328. super().__init__(
  329. parse_qsl(value.decode("latin-1"), keep_blank_values=True), **kwargs
  330. )
  331. else:
  332. super().__init__(*args, **kwargs) # type: ignore[arg-type]
  333. self._list = [(str(k), str(v)) for k, v in self._list]
  334. self._dict = {str(k): str(v) for k, v in self._dict.items()}
  335. def __str__(self) -> str:
  336. return urlencode(self._list)
  337. def __repr__(self) -> str:
  338. class_name = self.__class__.__name__
  339. query_string = str(self)
  340. return f"{class_name}({query_string!r})"
  341. class UploadFile:
  342. """
  343. An uploaded file included as part of the request data.
  344. """
  345. def __init__(
  346. self,
  347. file: typing.BinaryIO,
  348. *,
  349. size: typing.Optional[int] = None,
  350. filename: typing.Optional[str] = None,
  351. headers: "typing.Optional[Headers]" = None,
  352. ) -> None:
  353. self.filename = filename
  354. self.file = file
  355. self.size = size
  356. self.headers = headers or Headers()
  357. @property
  358. def content_type(self) -> typing.Optional[str]:
  359. return self.headers.get("content-type", None)
  360. @property
  361. def _in_memory(self) -> bool:
  362. # check for SpooledTemporaryFile._rolled
  363. rolled_to_disk = getattr(self.file, "_rolled", True)
  364. return not rolled_to_disk
  365. async def write(self, data: bytes) -> None:
  366. if self.size is not None:
  367. self.size += len(data)
  368. if self._in_memory:
  369. self.file.write(data)
  370. else:
  371. await run_in_threadpool(self.file.write, data)
  372. async def read(self, size: int = -1) -> bytes:
  373. if self._in_memory:
  374. return self.file.read(size)
  375. return await run_in_threadpool(self.file.read, size)
  376. async def seek(self, offset: int) -> None:
  377. if self._in_memory:
  378. self.file.seek(offset)
  379. else:
  380. await run_in_threadpool(self.file.seek, offset)
  381. async def close(self) -> None:
  382. if self._in_memory:
  383. self.file.close()
  384. else:
  385. await run_in_threadpool(self.file.close)
  386. class FormData(ImmutableMultiDict[str, typing.Union[UploadFile, str]]):
  387. """
  388. An immutable multidict, containing both file uploads and text input.
  389. """
  390. def __init__(
  391. self,
  392. *args: typing.Union[
  393. "FormData",
  394. typing.Mapping[str, typing.Union[str, UploadFile]],
  395. typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]],
  396. ],
  397. **kwargs: typing.Union[str, UploadFile],
  398. ) -> None:
  399. super().__init__(*args, **kwargs)
  400. async def close(self) -> None:
  401. for key, value in self.multi_items():
  402. if isinstance(value, UploadFile):
  403. await value.close()
  404. class Headers(typing.Mapping[str, str]):
  405. """
  406. An immutable, case-insensitive multidict.
  407. """
  408. def __init__(
  409. self,
  410. headers: typing.Optional[typing.Mapping[str, str]] = None,
  411. raw: typing.Optional[typing.List[typing.Tuple[bytes, bytes]]] = None,
  412. scope: typing.Optional[typing.MutableMapping[str, typing.Any]] = None,
  413. ) -> None:
  414. self._list: typing.List[typing.Tuple[bytes, bytes]] = []
  415. if headers is not None:
  416. assert raw is None, 'Cannot set both "headers" and "raw".'
  417. assert scope is None, 'Cannot set both "headers" and "scope".'
  418. self._list = [
  419. (key.lower().encode("latin-1"), value.encode("latin-1"))
  420. for key, value in headers.items()
  421. ]
  422. elif raw is not None:
  423. assert scope is None, 'Cannot set both "raw" and "scope".'
  424. self._list = raw
  425. elif scope is not None:
  426. # scope["headers"] isn't necessarily a list
  427. # it might be a tuple or other iterable
  428. self._list = scope["headers"] = list(scope["headers"])
  429. @property
  430. def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
  431. return list(self._list)
  432. def keys(self) -> typing.List[str]: # type: ignore[override]
  433. return [key.decode("latin-1") for key, value in self._list]
  434. def values(self) -> typing.List[str]: # type: ignore[override]
  435. return [value.decode("latin-1") for key, value in self._list]
  436. def items(self) -> typing.List[typing.Tuple[str, str]]: # type: ignore[override]
  437. return [
  438. (key.decode("latin-1"), value.decode("latin-1"))
  439. for key, value in self._list
  440. ]
  441. def getlist(self, key: str) -> typing.List[str]:
  442. get_header_key = key.lower().encode("latin-1")
  443. return [
  444. item_value.decode("latin-1")
  445. for item_key, item_value in self._list
  446. if item_key == get_header_key
  447. ]
  448. def mutablecopy(self) -> "MutableHeaders":
  449. return MutableHeaders(raw=self._list[:])
  450. def __getitem__(self, key: str) -> str:
  451. get_header_key = key.lower().encode("latin-1")
  452. for header_key, header_value in self._list:
  453. if header_key == get_header_key:
  454. return header_value.decode("latin-1")
  455. raise KeyError(key)
  456. def __contains__(self, key: typing.Any) -> bool:
  457. get_header_key = key.lower().encode("latin-1")
  458. for header_key, header_value in self._list:
  459. if header_key == get_header_key:
  460. return True
  461. return False
  462. def __iter__(self) -> typing.Iterator[typing.Any]:
  463. return iter(self.keys())
  464. def __len__(self) -> int:
  465. return len(self._list)
  466. def __eq__(self, other: typing.Any) -> bool:
  467. if not isinstance(other, Headers):
  468. return False
  469. return sorted(self._list) == sorted(other._list)
  470. def __repr__(self) -> str:
  471. class_name = self.__class__.__name__
  472. as_dict = dict(self.items())
  473. if len(as_dict) == len(self):
  474. return f"{class_name}({as_dict!r})"
  475. return f"{class_name}(raw={self.raw!r})"
  476. class MutableHeaders(Headers):
  477. def __setitem__(self, key: str, value: str) -> None:
  478. """
  479. Set the header `key` to `value`, removing any duplicate entries.
  480. Retains insertion order.
  481. """
  482. set_key = key.lower().encode("latin-1")
  483. set_value = value.encode("latin-1")
  484. found_indexes: "typing.List[int]" = []
  485. for idx, (item_key, item_value) in enumerate(self._list):
  486. if item_key == set_key:
  487. found_indexes.append(idx)
  488. for idx in reversed(found_indexes[1:]):
  489. del self._list[idx]
  490. if found_indexes:
  491. idx = found_indexes[0]
  492. self._list[idx] = (set_key, set_value)
  493. else:
  494. self._list.append((set_key, set_value))
  495. def __delitem__(self, key: str) -> None:
  496. """
  497. Remove the header `key`.
  498. """
  499. del_key = key.lower().encode("latin-1")
  500. pop_indexes: "typing.List[int]" = []
  501. for idx, (item_key, item_value) in enumerate(self._list):
  502. if item_key == del_key:
  503. pop_indexes.append(idx)
  504. for idx in reversed(pop_indexes):
  505. del self._list[idx]
  506. def __ior__(self, other: typing.Mapping[str, str]) -> "MutableHeaders":
  507. if not isinstance(other, typing.Mapping):
  508. raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")
  509. self.update(other)
  510. return self
  511. def __or__(self, other: typing.Mapping[str, str]) -> "MutableHeaders":
  512. if not isinstance(other, typing.Mapping):
  513. raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")
  514. new = self.mutablecopy()
  515. new.update(other)
  516. return new
  517. @property
  518. def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
  519. return self._list
  520. def setdefault(self, key: str, value: str) -> str:
  521. """
  522. If the header `key` does not exist, then set it to `value`.
  523. Returns the header value.
  524. """
  525. set_key = key.lower().encode("latin-1")
  526. set_value = value.encode("latin-1")
  527. for idx, (item_key, item_value) in enumerate(self._list):
  528. if item_key == set_key:
  529. return item_value.decode("latin-1")
  530. self._list.append((set_key, set_value))
  531. return value
  532. def update(self, other: typing.Mapping[str, str]) -> None:
  533. for key, val in other.items():
  534. self[key] = val
  535. def append(self, key: str, value: str) -> None:
  536. """
  537. Append a header, preserving any duplicate entries.
  538. """
  539. append_key = key.lower().encode("latin-1")
  540. append_value = value.encode("latin-1")
  541. self._list.append((append_key, append_value))
  542. def add_vary_header(self, vary: str) -> None:
  543. existing = self.get("vary")
  544. if existing is not None:
  545. vary = ", ".join([existing, vary])
  546. self["vary"] = vary
  547. class State:
  548. """
  549. An object that can be used to store arbitrary state.
  550. Used for `request.state` and `app.state`.
  551. """
  552. _state: typing.Dict[str, typing.Any]
  553. def __init__(self, state: typing.Optional[typing.Dict[str, typing.Any]] = None):
  554. if state is None:
  555. state = {}
  556. super().__setattr__("_state", state)
  557. def __setattr__(self, key: typing.Any, value: typing.Any) -> None:
  558. self._state[key] = value
  559. def __getattr__(self, key: typing.Any) -> typing.Any:
  560. try:
  561. return self._state[key]
  562. except KeyError:
  563. message = "'{}' object has no attribute '{}'"
  564. raise AttributeError(message.format(self.__class__.__name__, key))
  565. def __delattr__(self, key: typing.Any) -> None:
  566. del self._state[key]