formparsers.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. import typing
  2. from dataclasses import dataclass, field
  3. from enum import Enum
  4. from tempfile import SpooledTemporaryFile
  5. from urllib.parse import unquote_plus
  6. from starlette.datastructures import FormData, Headers, UploadFile
  7. try:
  8. import multipart
  9. from multipart.multipart import parse_options_header
  10. except ModuleNotFoundError: # pragma: nocover
  11. parse_options_header = None
  12. multipart = None
  class FormMessage(Enum):
      """Kinds of events queued by the ``FormParser`` callbacks during parsing."""

      FIELD_START = 1
      FIELD_NAME = 2
      FIELD_DATA = 3
      FIELD_END = 4
      END = 5
  @dataclass
  class MultipartPart:
      """State collected for one part of a multipart body while it is parsed."""

      # Raw value of the part's Content-Disposition header, once one is seen.
      content_disposition: typing.Optional[bytes] = None
      # Decoded "name" option taken from the Content-Disposition header.
      field_name: str = ""
      # Bytes accumulated for a part that is not a file upload.
      data: bytes = b""
      # Upload object for file parts; stays None for plain fields.
      file: typing.Optional[UploadFile] = None
      # (lower-cased header name, raw header value) pairs, in arrival order.
      item_headers: typing.List[typing.Tuple[bytes, bytes]] = field(
          default_factory=list
      )
  26. def _user_safe_decode(src: bytes, codec: str) -> str:
  27. try:
  28. return src.decode(codec)
  29. except (UnicodeDecodeError, LookupError):
  30. return src.decode("latin-1")
  class MultiPartException(Exception):
      """Raised when a multipart body is malformed or exceeds a configured limit."""

      def __init__(self, message: str) -> None:
          # Forward the message to Exception explicitly so ``args``/``str()``
          # do not rely on the implicit handling in ``BaseException.__new__``.
          super().__init__(message)
          self.message = message
  class FormParser:
      """Parse a URL-encoded form body from an async byte stream into ``FormData``.

      The ``multipart.QuerystringParser`` callbacks only queue
      ``(FormMessage, bytes)`` tuples on ``self.messages``; ``parse()`` drains
      that queue after every chunk and assembles the name/value pairs.
      """

      def __init__(
          self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]
      ) -> None:
          assert (
              multipart is not None
          ), "The `python-multipart` library must be installed to use form parsing."
          self.headers = headers
          self.stream = stream
          # Events recorded by the parser callbacks, consumed by ``parse()``.
          self.messages: typing.List[typing.Tuple[FormMessage, bytes]] = []

      def on_field_start(self) -> None:
          """Queue a FIELD_START event."""
          message = (FormMessage.FIELD_START, b"")
          self.messages.append(message)

      def on_field_name(self, data: bytes, start: int, end: int) -> None:
          """Queue the ``data[start:end]`` slice as a piece of the field name."""
          message = (FormMessage.FIELD_NAME, data[start:end])
          self.messages.append(message)

      def on_field_data(self, data: bytes, start: int, end: int) -> None:
          """Queue the ``data[start:end]`` slice as a piece of the field value."""
          message = (FormMessage.FIELD_DATA, data[start:end])
          self.messages.append(message)

      def on_field_end(self) -> None:
          """Queue a FIELD_END event."""
          message = (FormMessage.FIELD_END, b"")
          self.messages.append(message)

      def on_end(self) -> None:
          """Queue an END event."""
          message = (FormMessage.END, b"")
          self.messages.append(message)

      async def parse(self) -> FormData:
          """Consume ``self.stream`` and return the decoded fields as ``FormData``.

          Field names and values are decoded as latin-1 and then passed through
          ``unquote_plus``.
          """
          # Callbacks dictionary.
          callbacks = {
              "on_field_start": self.on_field_start,
              "on_field_name": self.on_field_name,
              "on_field_data": self.on_field_data,
              "on_field_end": self.on_field_end,
              "on_end": self.on_end,
          }

          # Create the parser.
          parser = multipart.QuerystringParser(callbacks)
          field_name = b""
          field_value = b""

          items: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]] = []

          # Feed the parser with data from the request.
          async for chunk in self.stream:
              if chunk:
                  parser.write(chunk)
              else:
                  # An empty chunk is treated as the end of the body.
                  parser.finalize()

              # Take a snapshot of the queued events and reset the queue before
              # processing them.
              messages = list(self.messages)
              self.messages.clear()
              for message_type, message_bytes in messages:
                  if message_type == FormMessage.FIELD_START:
                      field_name = b""
                      field_value = b""
                  elif message_type == FormMessage.FIELD_NAME:
                      field_name += message_bytes
                  elif message_type == FormMessage.FIELD_DATA:
                      field_value += message_bytes
                  elif message_type == FormMessage.FIELD_END:
                      name = unquote_plus(field_name.decode("latin-1"))
                      value = unquote_plus(field_value.decode("latin-1"))
                      items.append((name, value))

          return FormData(items)
  class MultiPartParser:
      """Parse a ``multipart/form-data`` body from an async byte stream.

      The ``multipart.MultipartParser`` callbacks are plain synchronous
      functions, so they only record state; the awaitable file writes and seeks
      are performed in ``parse()`` after each chunk has been fed to the parser.
      """

      # Passed as ``max_size`` to each part's ``SpooledTemporaryFile``.
      max_file_size = 1024 * 1024

      def __init__(
          self,
          headers: Headers,
          stream: typing.AsyncGenerator[bytes, None],
          *,
          max_files: typing.Union[int, float] = 1000,
          max_fields: typing.Union[int, float] = 1000,
      ) -> None:
          assert (
              multipart is not None
          ), "The `python-multipart` library must be installed to use form parsing."
          self.headers = headers
          self.stream = stream
          self.max_files = max_files
          self.max_fields = max_fields
          self.items: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]] = []
          self._current_files = 0
          self._current_fields = 0
          self._current_partial_header_name: bytes = b""
          self._current_partial_header_value: bytes = b""
          self._current_part = MultipartPart()
          self._charset = ""
          # Work queued by the callbacks and performed (awaited) in ``parse()``.
          self._file_parts_to_write: typing.List[
              typing.Tuple[MultipartPart, bytes]
          ] = []
          self._file_parts_to_finish: typing.List[MultipartPart] = []
          self._files_to_close_on_error: typing.List[SpooledTemporaryFile] = []

      def on_part_begin(self) -> None:
          """Start a fresh ``MultipartPart`` for the part that is beginning."""
          self._current_part = MultipartPart()

      def on_part_data(self, data: bytes, start: int, end: int) -> None:
          """Buffer field bytes in memory, or queue file bytes for a later write."""
          message_bytes = data[start:end]
          if self._current_part.file is None:
              self._current_part.data += message_bytes
          else:
              self._file_parts_to_write.append((self._current_part, message_bytes))

      def on_part_end(self) -> None:
          """Record the finished part in ``self.items``."""
          if self._current_part.file is None:
              self.items.append(
                  (
                      self._current_part.field_name,
                      _user_safe_decode(self._current_part.data, self._charset),
                  )
              )
          else:
              self._file_parts_to_finish.append(self._current_part)
              # The file can be added to the items right now even though it's not
              # finished yet, because it will be finished in the `parse()` method,
              # before self.items is used in the return value.
              self.items.append(
                  (self._current_part.field_name, self._current_part.file)
              )

      def on_header_field(self, data: bytes, start: int, end: int) -> None:
          """Append ``data[start:end]`` to the header name being assembled."""
          self._current_partial_header_name += data[start:end]

      def on_header_value(self, data: bytes, start: int, end: int) -> None:
          """Append ``data[start:end]`` to the header value being assembled."""
          self._current_partial_header_value += data[start:end]

      def on_header_end(self) -> None:
          """Store the completed header on the current part and reset the buffers."""
          field = self._current_partial_header_name.lower()
          if field == b"content-disposition":
              self._current_part.content_disposition = (
                  self._current_partial_header_value
              )
          self._current_part.item_headers.append(
              (field, self._current_partial_header_value)
          )
          self._current_partial_header_name = b""
          self._current_partial_header_value = b""

      def on_headers_finished(self) -> None:
          """Classify the current part as a file or a field and enforce the limits.

          Raises:
              MultiPartException: if the Content-Disposition header has no
                  ``name`` option, or if ``max_files`` / ``max_fields`` is
                  exceeded.
          """
          _, options = parse_options_header(self._current_part.content_disposition)
          try:
              self._current_part.field_name = _user_safe_decode(
                  options[b"name"], self._charset
              )
          except KeyError:
              raise MultiPartException(
                  'The Content-Disposition header field "name" must be provided.'
              )
          if b"filename" in options:
              self._current_files += 1
              if self._current_files > self.max_files:
                  raise MultiPartException(
                      f"Too many files. Maximum number of files is {self.max_files}."
                  )
              filename = _user_safe_decode(options[b"filename"], self._charset)
              tempfile = SpooledTemporaryFile(max_size=self.max_file_size)
              # Remember the file so ``parse()`` can close it if parsing fails.
              self._files_to_close_on_error.append(tempfile)
              self._current_part.file = UploadFile(
                  file=tempfile,  # type: ignore[arg-type]
                  size=0,
                  filename=filename,
                  headers=Headers(raw=self._current_part.item_headers),
              )
          else:
              self._current_fields += 1
              if self._current_fields > self.max_fields:
                  raise MultiPartException(
                      f"Too many fields. Maximum number of fields is {self.max_fields}."
                  )
              self._current_part.file = None

      def on_end(self) -> None:
          """No-op callback invoked by the parser at the end of the body."""
          pass

      async def parse(self) -> FormData:
          """Consume ``self.stream`` and return the parsed parts as ``FormData``.

          Raises:
              MultiPartException: if the Content-Type header has no boundary, or
                  if a callback rejects the body. Temporary files created so far
                  are closed before a callback error propagates.
          """
          # Parse the Content-Type header to get the multipart boundary.
          _, params = parse_options_header(self.headers["Content-Type"])
          charset = params.get(b"charset", "utf-8")
          # The default above is a str, while a charset taken from the header is
          # bytes; normalise to str.
          if isinstance(charset, bytes):
              charset = charset.decode("latin-1")
          self._charset = charset
          try:
              boundary = params[b"boundary"]
          except KeyError:
              raise MultiPartException("Missing boundary in multipart.")

          # Callbacks dictionary.
          callbacks = {
              "on_part_begin": self.on_part_begin,
              "on_part_data": self.on_part_data,
              "on_part_end": self.on_part_end,
              "on_header_field": self.on_header_field,
              "on_header_value": self.on_header_value,
              "on_header_end": self.on_header_end,
              "on_headers_finished": self.on_headers_finished,
              "on_end": self.on_end,
          }

          # Create the parser.
          parser = multipart.MultipartParser(boundary, callbacks)
          try:
              # Feed the parser with data from the request.
              async for chunk in self.stream:
                  parser.write(chunk)
                  # Write file data, it needs to use await with the UploadFile
                  # methods that call the corresponding file methods *in a
                  # threadpool*, otherwise, if they were called directly in the
                  # callback methods above (regular, non-async functions), that
                  # would block the event loop in the main thread.
                  for part, data in self._file_parts_to_write:
                      assert part.file  # for type checkers
                      await part.file.write(data)
                  for part in self._file_parts_to_finish:
                      assert part.file  # for type checkers
                      await part.file.seek(0)
                  self._file_parts_to_write.clear()
                  self._file_parts_to_finish.clear()
          except MultiPartException:
              # Close all the files if there was an error.
              for file in self._files_to_close_on_error:
                  file.close()
              # Bare ``raise`` re-raises the active exception unchanged.
              raise

          parser.finalize()
          return FormData(self.items)
  239. return FormData(self.items)