diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/starlette/middleware/wsgi.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/starlette/middleware/wsgi.py | 152 |
1 files changed, 152 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/starlette/middleware/wsgi.py b/.venv/lib/python3.12/site-packages/starlette/middleware/wsgi.py new file mode 100644 index 00000000..6e0a3fae --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/middleware/wsgi.py @@ -0,0 +1,152 @@ +from __future__ import annotations + +import io +import math +import sys +import typing +import warnings + +import anyio +from anyio.abc import ObjectReceiveStream, ObjectSendStream + +from starlette.types import Receive, Scope, Send + +warnings.warn( + "starlette.middleware.wsgi is deprecated and will be removed in a future release. " + "Please refer to https://github.com/abersheeran/a2wsgi as a replacement.", + DeprecationWarning, +) + + +def build_environ(scope: Scope, body: bytes) -> dict[str, typing.Any]: + """ + Builds a scope and request body into a WSGI environ object. + """ + + script_name = scope.get("root_path", "").encode("utf8").decode("latin1") + path_info = scope["path"].encode("utf8").decode("latin1") + if path_info.startswith(script_name): + path_info = path_info[len(script_name) :] + + environ = { + "REQUEST_METHOD": scope["method"], + "SCRIPT_NAME": script_name, + "PATH_INFO": path_info, + "QUERY_STRING": scope["query_string"].decode("ascii"), + "SERVER_PROTOCOL": f"HTTP/{scope['http_version']}", + "wsgi.version": (1, 0), + "wsgi.url_scheme": scope.get("scheme", "http"), + "wsgi.input": io.BytesIO(body), + "wsgi.errors": sys.stdout, + "wsgi.multithread": True, + "wsgi.multiprocess": True, + "wsgi.run_once": False, + } + + # Get server name and port - required in WSGI, not in ASGI + server = scope.get("server") or ("localhost", 80) + environ["SERVER_NAME"] = server[0] + environ["SERVER_PORT"] = server[1] + + # Get client IP address + if scope.get("client"): + environ["REMOTE_ADDR"] = scope["client"][0] + + # Go through headers and make them into environ entries + for name, value in scope.get("headers", []): + name = name.decode("latin1") + if name == "content-length": + corrected_name = "CONTENT_LENGTH" + elif name == "content-type": + corrected_name = "CONTENT_TYPE" + else: + corrected_name = f"HTTP_{name}".upper().replace("-", "_") + # HTTPbis say only ASCII chars are allowed in headers, but we latin1 just in + # case + value = value.decode("latin1") + if corrected_name in environ: + value = environ[corrected_name] + "," + value + environ[corrected_name] = value + return environ + + +class WSGIMiddleware: + def __init__(self, app: typing.Callable[..., typing.Any]) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + assert scope["type"] == "http" + responder = WSGIResponder(self.app, scope) + await responder(receive, send) + + +class WSGIResponder: + stream_send: ObjectSendStream[typing.MutableMapping[str, typing.Any]] + stream_receive: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]] + + def __init__(self, app: typing.Callable[..., typing.Any], scope: Scope) -> None: + self.app = app + self.scope = scope + self.status = None + self.response_headers = None + self.stream_send, self.stream_receive = anyio.create_memory_object_stream(math.inf) + self.response_started = False + self.exc_info: typing.Any = None + + async def __call__(self, receive: Receive, send: Send) -> None: + body = b"" + more_body = True + while more_body: + message = await receive() + body += message.get("body", b"") + more_body = message.get("more_body", False) + environ = build_environ(self.scope, body) + + async with anyio.create_task_group() as task_group: + task_group.start_soon(self.sender, send) + async with self.stream_send: + await anyio.to_thread.run_sync(self.wsgi, environ, self.start_response) + if self.exc_info is not None: + raise self.exc_info[0].with_traceback(self.exc_info[1], self.exc_info[2]) + + async def sender(self, send: Send) -> None: + async with self.stream_receive: + async for message in self.stream_receive: + await send(message) + + def start_response( + self, + status: str, + response_headers: list[tuple[str, str]], + exc_info: typing.Any = None, + ) -> None: + self.exc_info = exc_info + if not self.response_started: # pragma: no branch + self.response_started = True + status_code_string, _ = status.split(" ", 1) + status_code = int(status_code_string) + headers = [ + (name.strip().encode("ascii").lower(), value.strip().encode("ascii")) + for name, value in response_headers + ] + anyio.from_thread.run( + self.stream_send.send, + { + "type": "http.response.start", + "status": status_code, + "headers": headers, + }, + ) + + def wsgi( + self, + environ: dict[str, typing.Any], + start_response: typing.Callable[..., typing.Any], + ) -> None: + for chunk in self.app(environ, start_response): + anyio.from_thread.run( + self.stream_send.send, + {"type": "http.response.body", "body": chunk, "more_body": True}, + ) + + anyio.from_thread.run(self.stream_send.send, {"type": "http.response.body", "body": b""}) |