_qs.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. from __future__ import annotations
  2. from typing import Any, List, Tuple, Union, Mapping, TypeVar
  3. from urllib.parse import parse_qs, urlencode
  4. from typing_extensions import Literal, get_args
  5. from ._types import NotGiven, not_given
  6. from ._utils import flatten
  7. _T = TypeVar("_T")
  8. ArrayFormat = Literal["comma", "repeat", "indices", "brackets"]
  9. NestedFormat = Literal["dots", "brackets"]
  10. PrimitiveData = Union[str, int, float, bool, None]
  11. # this should be Data = Union[PrimitiveData, "List[Data]", "Tuple[Data]", "Mapping[str, Data]"]
  12. # https://github.com/microsoft/pyright/issues/3555
  13. Data = Union[PrimitiveData, List[Any], Tuple[Any], "Mapping[str, Any]"]
  14. Params = Mapping[str, Data]
  15. class Querystring:
  16. array_format: ArrayFormat
  17. nested_format: NestedFormat
  18. def __init__(
  19. self,
  20. *,
  21. array_format: ArrayFormat = "repeat",
  22. nested_format: NestedFormat = "brackets",
  23. ) -> None:
  24. self.array_format = array_format
  25. self.nested_format = nested_format
  26. def parse(self, query: str) -> Mapping[str, object]:
  27. # Note: custom format syntax is not supported yet
  28. return parse_qs(query)
  29. def stringify(
  30. self,
  31. params: Params,
  32. *,
  33. array_format: ArrayFormat | NotGiven = not_given,
  34. nested_format: NestedFormat | NotGiven = not_given,
  35. ) -> str:
  36. return urlencode(
  37. self.stringify_items(
  38. params,
  39. array_format=array_format,
  40. nested_format=nested_format,
  41. )
  42. )
  43. def stringify_items(
  44. self,
  45. params: Params,
  46. *,
  47. array_format: ArrayFormat | NotGiven = not_given,
  48. nested_format: NestedFormat | NotGiven = not_given,
  49. ) -> list[tuple[str, str]]:
  50. opts = Options(
  51. qs=self,
  52. array_format=array_format,
  53. nested_format=nested_format,
  54. )
  55. return flatten([self._stringify_item(key, value, opts) for key, value in params.items()])
  56. def _stringify_item(
  57. self,
  58. key: str,
  59. value: Data,
  60. opts: Options,
  61. ) -> list[tuple[str, str]]:
  62. if isinstance(value, Mapping):
  63. items: list[tuple[str, str]] = []
  64. nested_format = opts.nested_format
  65. for subkey, subvalue in value.items():
  66. items.extend(
  67. self._stringify_item(
  68. # TODO: error if unknown format
  69. f"{key}.{subkey}" if nested_format == "dots" else f"{key}[{subkey}]",
  70. subvalue,
  71. opts,
  72. )
  73. )
  74. return items
  75. if isinstance(value, (list, tuple)):
  76. array_format = opts.array_format
  77. if array_format == "comma":
  78. return [
  79. (
  80. key,
  81. ",".join(self._primitive_value_to_str(item) for item in value if item is not None),
  82. ),
  83. ]
  84. elif array_format == "repeat":
  85. items = []
  86. for item in value:
  87. items.extend(self._stringify_item(key, item, opts))
  88. return items
  89. elif array_format == "indices":
  90. raise NotImplementedError("The array indices format is not supported yet")
  91. elif array_format == "brackets":
  92. items = []
  93. key = key + "[]"
  94. for item in value:
  95. items.extend(self._stringify_item(key, item, opts))
  96. return items
  97. else:
  98. raise NotImplementedError(
  99. f"Unknown array_format value: {array_format}, choose from {', '.join(get_args(ArrayFormat))}"
  100. )
  101. serialised = self._primitive_value_to_str(value)
  102. if not serialised:
  103. return []
  104. return [(key, serialised)]
  105. def _primitive_value_to_str(self, value: PrimitiveData) -> str:
  106. # copied from httpx
  107. if value is True:
  108. return "true"
  109. elif value is False:
  110. return "false"
  111. elif value is None:
  112. return ""
  113. return str(value)
  114. _qs = Querystring()
  115. parse = _qs.parse
  116. stringify = _qs.stringify
  117. stringify_items = _qs.stringify_items
  118. class Options:
  119. array_format: ArrayFormat
  120. nested_format: NestedFormat
  121. def __init__(
  122. self,
  123. qs: Querystring = _qs,
  124. *,
  125. array_format: ArrayFormat | NotGiven = not_given,
  126. nested_format: NestedFormat | NotGiven = not_given,
  127. ) -> None:
  128. self.array_format = qs.array_format if isinstance(array_format, NotGiven) else array_format
  129. self.nested_format = qs.nested_format if isinstance(nested_format, NotGiven) else nested_format