gzip.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. import gzip
  2. import io
  3. import typing
  4. from starlette.datastructures import Headers, MutableHeaders
  5. from starlette.types import ASGIApp, Message, Receive, Scope, Send
  6. class GZipMiddleware:
  7. def __init__(
  8. self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9
  9. ) -> None:
  10. self.app = app
  11. self.minimum_size = minimum_size
  12. self.compresslevel = compresslevel
  13. async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
  14. if scope["type"] == "http":
  15. headers = Headers(scope=scope)
  16. if "gzip" in headers.get("Accept-Encoding", ""):
  17. responder = GZipResponder(
  18. self.app, self.minimum_size, compresslevel=self.compresslevel
  19. )
  20. await responder(scope, receive, send)
  21. return
  22. await self.app(scope, receive, send)
  23. class GZipResponder:
  24. def __init__(self, app: ASGIApp, minimum_size: int, compresslevel: int = 9) -> None:
  25. self.app = app
  26. self.minimum_size = minimum_size
  27. self.send: Send = unattached_send
  28. self.initial_message: Message = {}
  29. self.started = False
  30. self.content_encoding_set = False
  31. self.gzip_buffer = io.BytesIO()
  32. self.gzip_file = gzip.GzipFile(
  33. mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel
  34. )
  35. async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
  36. self.send = send
  37. await self.app(scope, receive, self.send_with_gzip)
  38. async def send_with_gzip(self, message: Message) -> None:
  39. message_type = message["type"]
  40. if message_type == "http.response.start":
  41. # Don't send the initial message until we've determined how to
  42. # modify the outgoing headers correctly.
  43. self.initial_message = message
  44. headers = Headers(raw=self.initial_message["headers"])
  45. self.content_encoding_set = "content-encoding" in headers
  46. elif message_type == "http.response.body" and self.content_encoding_set:
  47. if not self.started:
  48. self.started = True
  49. await self.send(self.initial_message)
  50. await self.send(message)
  51. elif message_type == "http.response.body" and not self.started:
  52. self.started = True
  53. body = message.get("body", b"")
  54. more_body = message.get("more_body", False)
  55. if len(body) < self.minimum_size and not more_body:
  56. # Don't apply GZip to small outgoing responses.
  57. await self.send(self.initial_message)
  58. await self.send(message)
  59. elif not more_body:
  60. # Standard GZip response.
  61. self.gzip_file.write(body)
  62. self.gzip_file.close()
  63. body = self.gzip_buffer.getvalue()
  64. headers = MutableHeaders(raw=self.initial_message["headers"])
  65. headers["Content-Encoding"] = "gzip"
  66. headers["Content-Length"] = str(len(body))
  67. headers.add_vary_header("Accept-Encoding")
  68. message["body"] = body
  69. await self.send(self.initial_message)
  70. await self.send(message)
  71. else:
  72. # Initial body in streaming GZip response.
  73. headers = MutableHeaders(raw=self.initial_message["headers"])
  74. headers["Content-Encoding"] = "gzip"
  75. headers.add_vary_header("Accept-Encoding")
  76. del headers["Content-Length"]
  77. self.gzip_file.write(body)
  78. message["body"] = self.gzip_buffer.getvalue()
  79. self.gzip_buffer.seek(0)
  80. self.gzip_buffer.truncate()
  81. await self.send(self.initial_message)
  82. await self.send(message)
  83. elif message_type == "http.response.body":
  84. # Remaining body in streaming GZip response.
  85. body = message.get("body", b"")
  86. more_body = message.get("more_body", False)
  87. self.gzip_file.write(body)
  88. if not more_body:
  89. self.gzip_file.close()
  90. message["body"] = self.gzip_buffer.getvalue()
  91. self.gzip_buffer.seek(0)
  92. self.gzip_buffer.truncate()
  93. await self.send(message)
  94. async def unattached_send(message: Message) -> typing.NoReturn:
  95. raise RuntimeError("send awaitable not set") # pragma: no cover