diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/starlette/endpoints.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/starlette/endpoints.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/starlette/endpoints.py | 122 |
1 files changed, 122 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/starlette/endpoints.py b/.venv/lib/python3.12/site-packages/starlette/endpoints.py new file mode 100644 index 00000000..10769026 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/endpoints.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +import json +import typing + +from starlette import status +from starlette._utils import is_async_callable +from starlette.concurrency import run_in_threadpool +from starlette.exceptions import HTTPException +from starlette.requests import Request +from starlette.responses import PlainTextResponse, Response +from starlette.types import Message, Receive, Scope, Send +from starlette.websockets import WebSocket + + +class HTTPEndpoint: + def __init__(self, scope: Scope, receive: Receive, send: Send) -> None: + assert scope["type"] == "http" + self.scope = scope + self.receive = receive + self.send = send + self._allowed_methods = [ + method + for method in ("GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "OPTIONS") + if getattr(self, method.lower(), None) is not None + ] + + def __await__(self) -> typing.Generator[typing.Any, None, None]: + return self.dispatch().__await__() + + async def dispatch(self) -> None: + request = Request(self.scope, receive=self.receive) + handler_name = "get" if request.method == "HEAD" and not hasattr(self, "head") else request.method.lower() + + handler: typing.Callable[[Request], typing.Any] = getattr(self, handler_name, self.method_not_allowed) + is_async = is_async_callable(handler) + if is_async: + response = await handler(request) + else: + response = await run_in_threadpool(handler, request) + await response(self.scope, self.receive, self.send) + + async def method_not_allowed(self, request: Request) -> Response: + # If we're running inside a starlette application then raise an + # exception, so that the configurable exception handler can deal with + # returning the response. For plain ASGI apps, just return the response. + headers = {"Allow": ", ".join(self._allowed_methods)} + if "app" in self.scope: + raise HTTPException(status_code=405, headers=headers) + return PlainTextResponse("Method Not Allowed", status_code=405, headers=headers) + + +class WebSocketEndpoint: + encoding: str | None = None # May be "text", "bytes", or "json". + + def __init__(self, scope: Scope, receive: Receive, send: Send) -> None: + assert scope["type"] == "websocket" + self.scope = scope + self.receive = receive + self.send = send + + def __await__(self) -> typing.Generator[typing.Any, None, None]: + return self.dispatch().__await__() + + async def dispatch(self) -> None: + websocket = WebSocket(self.scope, receive=self.receive, send=self.send) + await self.on_connect(websocket) + + close_code = status.WS_1000_NORMAL_CLOSURE + + try: + while True: + message = await websocket.receive() + if message["type"] == "websocket.receive": + data = await self.decode(websocket, message) + await self.on_receive(websocket, data) + elif message["type"] == "websocket.disconnect": # pragma: no branch + close_code = int(message.get("code") or status.WS_1000_NORMAL_CLOSURE) + break + except Exception as exc: + close_code = status.WS_1011_INTERNAL_ERROR + raise exc + finally: + await self.on_disconnect(websocket, close_code) + + async def decode(self, websocket: WebSocket, message: Message) -> typing.Any: + if self.encoding == "text": + if "text" not in message: + await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA) + raise RuntimeError("Expected text websocket messages, but got bytes") + return message["text"] + + elif self.encoding == "bytes": + if "bytes" not in message: + await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA) + raise RuntimeError("Expected bytes websocket messages, but got text") + return message["bytes"] + + elif self.encoding == "json": + if message.get("text") is not None: + text = message["text"] + else: + text = message["bytes"].decode("utf-8") + + try: + return json.loads(text) + except json.decoder.JSONDecodeError: + await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA) + raise RuntimeError("Malformed JSON data received.") + + assert self.encoding is None, f"Unsupported 'encoding' attribute {self.encoding}" + return message["text"] if message.get("text") else message["bytes"] + + async def on_connect(self, websocket: WebSocket) -> None: + """Override to handle an incoming websocket connection""" + await websocket.accept() + + async def on_receive(self, websocket: WebSocket, data: typing.Any) -> None: + """Override to handle an incoming websocket message""" + + async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None: + """Override to handle a disconnecting websocket""" |