about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/starlette/endpoints.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/starlette/endpoints.py')
-rw-r--r--.venv/lib/python3.12/site-packages/starlette/endpoints.py122
1 files changed, 122 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/starlette/endpoints.py b/.venv/lib/python3.12/site-packages/starlette/endpoints.py
new file mode 100644
index 00000000..10769026
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/starlette/endpoints.py
@@ -0,0 +1,122 @@
+from __future__ import annotations
+
+import json
+import typing
+
+from starlette import status
+from starlette._utils import is_async_callable
+from starlette.concurrency import run_in_threadpool
+from starlette.exceptions import HTTPException
+from starlette.requests import Request
+from starlette.responses import PlainTextResponse, Response
+from starlette.types import Message, Receive, Scope, Send
+from starlette.websockets import WebSocket
+
+
+class HTTPEndpoint:
+    def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
+        assert scope["type"] == "http"
+        self.scope = scope
+        self.receive = receive
+        self.send = send
+        self._allowed_methods = [
+            method
+            for method in ("GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "OPTIONS")
+            if getattr(self, method.lower(), None) is not None
+        ]
+
+    def __await__(self) -> typing.Generator[typing.Any, None, None]:
+        return self.dispatch().__await__()
+
+    async def dispatch(self) -> None:
+        request = Request(self.scope, receive=self.receive)
+        handler_name = "get" if request.method == "HEAD" and not hasattr(self, "head") else request.method.lower()
+
+        handler: typing.Callable[[Request], typing.Any] = getattr(self, handler_name, self.method_not_allowed)
+        is_async = is_async_callable(handler)
+        if is_async:
+            response = await handler(request)
+        else:
+            response = await run_in_threadpool(handler, request)
+        await response(self.scope, self.receive, self.send)
+
+    async def method_not_allowed(self, request: Request) -> Response:
+        # If we're running inside a starlette application then raise an
+        # exception, so that the configurable exception handler can deal with
+        # returning the response. For plain ASGI apps, just return the response.
+        headers = {"Allow": ", ".join(self._allowed_methods)}
+        if "app" in self.scope:
+            raise HTTPException(status_code=405, headers=headers)
+        return PlainTextResponse("Method Not Allowed", status_code=405, headers=headers)
+
+
+class WebSocketEndpoint:
+    encoding: str | None = None  # May be "text", "bytes", or "json".
+
+    def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
+        assert scope["type"] == "websocket"
+        self.scope = scope
+        self.receive = receive
+        self.send = send
+
+    def __await__(self) -> typing.Generator[typing.Any, None, None]:
+        return self.dispatch().__await__()
+
+    async def dispatch(self) -> None:
+        websocket = WebSocket(self.scope, receive=self.receive, send=self.send)
+        await self.on_connect(websocket)
+
+        close_code = status.WS_1000_NORMAL_CLOSURE
+
+        try:
+            while True:
+                message = await websocket.receive()
+                if message["type"] == "websocket.receive":
+                    data = await self.decode(websocket, message)
+                    await self.on_receive(websocket, data)
+                elif message["type"] == "websocket.disconnect":  # pragma: no branch
+                    close_code = int(message.get("code") or status.WS_1000_NORMAL_CLOSURE)
+                    break
+        except Exception as exc:
+            close_code = status.WS_1011_INTERNAL_ERROR
+            raise exc
+        finally:
+            await self.on_disconnect(websocket, close_code)
+
+    async def decode(self, websocket: WebSocket, message: Message) -> typing.Any:
+        if self.encoding == "text":
+            if "text" not in message:
+                await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
+                raise RuntimeError("Expected text websocket messages, but got bytes")
+            return message["text"]
+
+        elif self.encoding == "bytes":
+            if "bytes" not in message:
+                await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
+                raise RuntimeError("Expected bytes websocket messages, but got text")
+            return message["bytes"]
+
+        elif self.encoding == "json":
+            if message.get("text") is not None:
+                text = message["text"]
+            else:
+                text = message["bytes"].decode("utf-8")
+
+            try:
+                return json.loads(text)
+            except json.decoder.JSONDecodeError:
+                await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA)
+                raise RuntimeError("Malformed JSON data received.")
+
+        assert self.encoding is None, f"Unsupported 'encoding' attribute {self.encoding}"
+        return message["text"] if message.get("text") else message["bytes"]
+
+    async def on_connect(self, websocket: WebSocket) -> None:
+        """Override to handle an incoming websocket connection"""
+        await websocket.accept()
+
+    async def on_receive(self, websocket: WebSocket, data: typing.Any) -> None:
+        """Override to handle an incoming websocket message"""
+
+    async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None:
+        """Override to handle a disconnecting websocket"""