wsgi.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. import io
  2. import math
  3. import sys
  4. import typing
  5. import warnings
  6. import anyio
  7. from starlette.types import Receive, Scope, Send
  8. warnings.warn(
  9. "starlette.middleware.wsgi is deprecated and will be removed in a future release. "
  10. "Please refer to https://github.com/abersheeran/a2wsgi as a replacement.",
  11. DeprecationWarning,
  12. )
  13. def build_environ(scope: Scope, body: bytes) -> dict:
  14. """
  15. Builds a scope and request body into a WSGI environ object.
  16. """
  17. environ = {
  18. "REQUEST_METHOD": scope["method"],
  19. "SCRIPT_NAME": scope.get("root_path", "").encode("utf8").decode("latin1"),
  20. "PATH_INFO": scope["path"].encode("utf8").decode("latin1"),
  21. "QUERY_STRING": scope["query_string"].decode("ascii"),
  22. "SERVER_PROTOCOL": f"HTTP/{scope['http_version']}",
  23. "wsgi.version": (1, 0),
  24. "wsgi.url_scheme": scope.get("scheme", "http"),
  25. "wsgi.input": io.BytesIO(body),
  26. "wsgi.errors": sys.stdout,
  27. "wsgi.multithread": True,
  28. "wsgi.multiprocess": True,
  29. "wsgi.run_once": False,
  30. }
  31. # Get server name and port - required in WSGI, not in ASGI
  32. server = scope.get("server") or ("localhost", 80)
  33. environ["SERVER_NAME"] = server[0]
  34. environ["SERVER_PORT"] = server[1]
  35. # Get client IP address
  36. if scope.get("client"):
  37. environ["REMOTE_ADDR"] = scope["client"][0]
  38. # Go through headers and make them into environ entries
  39. for name, value in scope.get("headers", []):
  40. name = name.decode("latin1")
  41. if name == "content-length":
  42. corrected_name = "CONTENT_LENGTH"
  43. elif name == "content-type":
  44. corrected_name = "CONTENT_TYPE"
  45. else:
  46. corrected_name = f"HTTP_{name}".upper().replace("-", "_")
  47. # HTTPbis say only ASCII chars are allowed in headers, but we latin1 just in
  48. # case
  49. value = value.decode("latin1")
  50. if corrected_name in environ:
  51. value = environ[corrected_name] + "," + value
  52. environ[corrected_name] = value
  53. return environ
  54. class WSGIMiddleware:
  55. def __init__(self, app: typing.Callable) -> None:
  56. self.app = app
  57. async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
  58. assert scope["type"] == "http"
  59. responder = WSGIResponder(self.app, scope)
  60. await responder(receive, send)
  61. class WSGIResponder:
  62. def __init__(self, app: typing.Callable, scope: Scope) -> None:
  63. self.app = app
  64. self.scope = scope
  65. self.status = None
  66. self.response_headers = None
  67. self.stream_send, self.stream_receive = anyio.create_memory_object_stream(
  68. math.inf
  69. )
  70. self.response_started = False
  71. self.exc_info: typing.Any = None
  72. async def __call__(self, receive: Receive, send: Send) -> None:
  73. body = b""
  74. more_body = True
  75. while more_body:
  76. message = await receive()
  77. body += message.get("body", b"")
  78. more_body = message.get("more_body", False)
  79. environ = build_environ(self.scope, body)
  80. async with anyio.create_task_group() as task_group:
  81. task_group.start_soon(self.sender, send)
  82. async with self.stream_send:
  83. await anyio.to_thread.run_sync(self.wsgi, environ, self.start_response)
  84. if self.exc_info is not None:
  85. raise self.exc_info[0].with_traceback(self.exc_info[1], self.exc_info[2])
  86. async def sender(self, send: Send) -> None:
  87. async with self.stream_receive:
  88. async for message in self.stream_receive:
  89. await send(message)
  90. def start_response(
  91. self,
  92. status: str,
  93. response_headers: typing.List[typing.Tuple[str, str]],
  94. exc_info: typing.Any = None,
  95. ) -> None:
  96. self.exc_info = exc_info
  97. if not self.response_started:
  98. self.response_started = True
  99. status_code_string, _ = status.split(" ", 1)
  100. status_code = int(status_code_string)
  101. headers = [
  102. (name.strip().encode("ascii").lower(), value.strip().encode("ascii"))
  103. for name, value in response_headers
  104. ]
  105. anyio.from_thread.run(
  106. self.stream_send.send,
  107. {
  108. "type": "http.response.start",
  109. "status": status_code,
  110. "headers": headers,
  111. },
  112. )
  113. def wsgi(self, environ: dict, start_response: typing.Callable) -> None:
  114. for chunk in self.app(environ, start_response):
  115. anyio.from_thread.run(
  116. self.stream_send.send,
  117. {"type": "http.response.body", "body": chunk, "more_body": True},
  118. )
  119. anyio.from_thread.run(
  120. self.stream_send.send, {"type": "http.response.body", "body": b""}
  121. )