diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/starlette | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/starlette')
35 files changed, 6631 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/starlette/__init__.py b/.venv/lib/python3.12/site-packages/starlette/__init__.py new file mode 100644 index 00000000..cffb82c5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/__init__.py @@ -0,0 +1 @@ +__version__ = "0.46.1" diff --git a/.venv/lib/python3.12/site-packages/starlette/_exception_handler.py b/.venv/lib/python3.12/site-packages/starlette/_exception_handler.py new file mode 100644 index 00000000..72bc89d9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/_exception_handler.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +import typing + +from starlette._utils import is_async_callable +from starlette.concurrency import run_in_threadpool +from starlette.exceptions import HTTPException +from starlette.requests import Request +from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send +from starlette.websockets import WebSocket + +ExceptionHandlers = dict[typing.Any, ExceptionHandler] +StatusHandlers = dict[int, ExceptionHandler] + + +def _lookup_exception_handler(exc_handlers: ExceptionHandlers, exc: Exception) -> ExceptionHandler | None: + for cls in type(exc).__mro__: + if cls in exc_handlers: + return exc_handlers[cls] + return None + + +def wrap_app_handling_exceptions(app: ASGIApp, conn: Request | WebSocket) -> ASGIApp: + exception_handlers: ExceptionHandlers + status_handlers: StatusHandlers + try: + exception_handlers, status_handlers = conn.scope["starlette.exception_handlers"] + except KeyError: + exception_handlers, status_handlers = {}, {} + + async def wrapped_app(scope: Scope, receive: Receive, send: Send) -> None: + response_started = False + + async def sender(message: Message) -> None: + nonlocal response_started + + if message["type"] == "http.response.start": + response_started = True + await send(message) + + try: + await app(scope, receive, sender) + except Exception as exc: + handler = None + + if isinstance(exc, HTTPException): + handler = status_handlers.get(exc.status_code) + + if handler is None: + handler = _lookup_exception_handler(exception_handlers, exc) + + if handler is None: + raise exc + + if response_started: + raise RuntimeError("Caught handled exception, but response already started.") from exc + + if is_async_callable(handler): + response = await handler(conn, exc) + else: + response = await run_in_threadpool(handler, conn, exc) # type: ignore + if response is not None: + await response(scope, receive, sender) + + return wrapped_app diff --git a/.venv/lib/python3.12/site-packages/starlette/_utils.py b/.venv/lib/python3.12/site-packages/starlette/_utils.py new file mode 100644 index 00000000..8001c472 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/_utils.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +import functools +import inspect +import sys +import typing +from contextlib import contextmanager + +from starlette.types import Scope + +if sys.version_info >= (3, 10): # pragma: no cover + from typing import TypeGuard +else: # pragma: no cover + from typing_extensions import TypeGuard + +has_exceptiongroups = True +if sys.version_info < (3, 11): # pragma: no cover + try: + from exceptiongroup import BaseExceptionGroup # type: ignore[unused-ignore,import-not-found] + except ImportError: + has_exceptiongroups = False + +T = typing.TypeVar("T") +AwaitableCallable = typing.Callable[..., typing.Awaitable[T]] + + +@typing.overload +def is_async_callable(obj: AwaitableCallable[T]) -> TypeGuard[AwaitableCallable[T]]: ... + + +@typing.overload +def is_async_callable(obj: typing.Any) -> TypeGuard[AwaitableCallable[typing.Any]]: ... + + +def is_async_callable(obj: typing.Any) -> typing.Any: + while isinstance(obj, functools.partial): + obj = obj.func + + return inspect.iscoroutinefunction(obj) or (callable(obj) and inspect.iscoroutinefunction(obj.__call__)) + + +T_co = typing.TypeVar("T_co", covariant=True) + + +class AwaitableOrContextManager(typing.Awaitable[T_co], typing.AsyncContextManager[T_co], typing.Protocol[T_co]): ... + + +class SupportsAsyncClose(typing.Protocol): + async def close(self) -> None: ... # pragma: no cover + + +SupportsAsyncCloseType = typing.TypeVar("SupportsAsyncCloseType", bound=SupportsAsyncClose, covariant=False) + + +class AwaitableOrContextManagerWrapper(typing.Generic[SupportsAsyncCloseType]): + __slots__ = ("aw", "entered") + + def __init__(self, aw: typing.Awaitable[SupportsAsyncCloseType]) -> None: + self.aw = aw + + def __await__(self) -> typing.Generator[typing.Any, None, SupportsAsyncCloseType]: + return self.aw.__await__() + + async def __aenter__(self) -> SupportsAsyncCloseType: + self.entered = await self.aw + return self.entered + + async def __aexit__(self, *args: typing.Any) -> None | bool: + await self.entered.close() + return None + + +@contextmanager +def collapse_excgroups() -> typing.Generator[None, None, None]: + try: + yield + except BaseException as exc: + if has_exceptiongroups: # pragma: no cover + while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1: + exc = exc.exceptions[0] + + raise exc + + +def get_route_path(scope: Scope) -> str: + path: str = scope["path"] + root_path = scope.get("root_path", "") + if not root_path: + return path + + if not path.startswith(root_path): + return path + + if path == root_path: + return "" + + if path[len(root_path)] == "/": + return path[len(root_path) :] + + return path diff --git a/.venv/lib/python3.12/site-packages/starlette/applications.py b/.venv/lib/python3.12/site-packages/starlette/applications.py new file mode 100644 index 00000000..6df5a707 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/applications.py @@ -0,0 +1,249 @@ +from __future__ import annotations + +import sys +import typing +import warnings + +if sys.version_info >= (3, 10): # pragma: no cover + from typing import ParamSpec +else: # pragma: no cover + from typing_extensions import ParamSpec + +from starlette.datastructures import State, URLPath +from starlette.middleware import Middleware, _MiddlewareFactory +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.middleware.errors import ServerErrorMiddleware +from starlette.middleware.exceptions import ExceptionMiddleware +from starlette.requests import Request +from starlette.responses import Response +from starlette.routing import BaseRoute, Router +from starlette.types import ASGIApp, ExceptionHandler, Lifespan, Receive, Scope, Send +from starlette.websockets import WebSocket + +AppType = typing.TypeVar("AppType", bound="Starlette") +P = ParamSpec("P") + + +class Starlette: + """Creates an Starlette application.""" + + def __init__( + self: AppType, + debug: bool = False, + routes: typing.Sequence[BaseRoute] | None = None, + middleware: typing.Sequence[Middleware] | None = None, + exception_handlers: typing.Mapping[typing.Any, ExceptionHandler] | None = None, + on_startup: typing.Sequence[typing.Callable[[], typing.Any]] | None = None, + on_shutdown: typing.Sequence[typing.Callable[[], typing.Any]] | None = None, + lifespan: Lifespan[AppType] | None = None, + ) -> None: + """Initializes the application. + + Parameters: + debug: Boolean indicating if debug tracebacks should be returned on errors. + routes: A list of routes to serve incoming HTTP and WebSocket requests. + middleware: A list of middleware to run for every request. A starlette + application will always automatically include two middleware classes. + `ServerErrorMiddleware` is added as the very outermost middleware, to handle + any uncaught errors occurring anywhere in the entire stack. + `ExceptionMiddleware` is added as the very innermost middleware, to deal + with handled exception cases occurring in the routing or endpoints. + exception_handlers: A mapping of either integer status codes, + or exception class types onto callables which handle the exceptions. + Exception handler callables should be of the form + `handler(request, exc) -> response` and may be either standard functions, or + async functions. + on_startup: A list of callables to run on application startup. + Startup handler callables do not take any arguments, and may be either + standard functions, or async functions. + on_shutdown: A list of callables to run on application shutdown. + Shutdown handler callables do not take any arguments, and may be either + standard functions, or async functions. + lifespan: A lifespan context function, which can be used to perform + startup and shutdown tasks. This is a newer style that replaces the + `on_startup` and `on_shutdown` handlers. Use one or the other, not both. + """ + # The lifespan context function is a newer style that replaces + # on_startup / on_shutdown handlers. Use one or the other, not both. + assert lifespan is None or (on_startup is None and on_shutdown is None), ( + "Use either 'lifespan' or 'on_startup'/'on_shutdown', not both." + ) + + self.debug = debug + self.state = State() + self.router = Router(routes, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan) + self.exception_handlers = {} if exception_handlers is None else dict(exception_handlers) + self.user_middleware = [] if middleware is None else list(middleware) + self.middleware_stack: ASGIApp | None = None + + def build_middleware_stack(self) -> ASGIApp: + debug = self.debug + error_handler = None + exception_handlers: dict[typing.Any, typing.Callable[[Request, Exception], Response]] = {} + + for key, value in self.exception_handlers.items(): + if key in (500, Exception): + error_handler = value + else: + exception_handlers[key] = value + + middleware = ( + [Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)] + + self.user_middleware + + [Middleware(ExceptionMiddleware, handlers=exception_handlers, debug=debug)] + ) + + app = self.router + for cls, args, kwargs in reversed(middleware): + app = cls(app, *args, **kwargs) + return app + + @property + def routes(self) -> list[BaseRoute]: + return self.router.routes + + def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath: + return self.router.url_path_for(name, **path_params) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + scope["app"] = self + if self.middleware_stack is None: + self.middleware_stack = self.build_middleware_stack() + await self.middleware_stack(scope, receive, send) + + def on_event(self, event_type: str) -> typing.Callable: # type: ignore[type-arg] + return self.router.on_event(event_type) # pragma: no cover + + def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None: + self.router.mount(path, app=app, name=name) # pragma: no cover + + def host(self, host: str, app: ASGIApp, name: str | None = None) -> None: + self.router.host(host, app=app, name=name) # pragma: no cover + + def add_middleware( + self, + middleware_class: _MiddlewareFactory[P], + *args: P.args, + **kwargs: P.kwargs, + ) -> None: + if self.middleware_stack is not None: # pragma: no cover + raise RuntimeError("Cannot add middleware after an application has started") + self.user_middleware.insert(0, Middleware(middleware_class, *args, **kwargs)) + + def add_exception_handler( + self, + exc_class_or_status_code: int | type[Exception], + handler: ExceptionHandler, + ) -> None: # pragma: no cover + self.exception_handlers[exc_class_or_status_code] = handler + + def add_event_handler( + self, + event_type: str, + func: typing.Callable, # type: ignore[type-arg] + ) -> None: # pragma: no cover + self.router.add_event_handler(event_type, func) + + def add_route( + self, + path: str, + route: typing.Callable[[Request], typing.Awaitable[Response] | Response], + methods: list[str] | None = None, + name: str | None = None, + include_in_schema: bool = True, + ) -> None: # pragma: no cover + self.router.add_route(path, route, methods=methods, name=name, include_in_schema=include_in_schema) + + def add_websocket_route( + self, + path: str, + route: typing.Callable[[WebSocket], typing.Awaitable[None]], + name: str | None = None, + ) -> None: # pragma: no cover + self.router.add_websocket_route(path, route, name=name) + + def exception_handler(self, exc_class_or_status_code: int | type[Exception]) -> typing.Callable: # type: ignore[type-arg] + warnings.warn( + "The `exception_handler` decorator is deprecated, and will be removed in version 1.0.0. " + "Refer to https://www.starlette.io/exceptions/ for the recommended approach.", + DeprecationWarning, + ) + + def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] + self.add_exception_handler(exc_class_or_status_code, func) + return func + + return decorator + + def route( + self, + path: str, + methods: list[str] | None = None, + name: str | None = None, + include_in_schema: bool = True, + ) -> typing.Callable: # type: ignore[type-arg] + """ + We no longer document this decorator style API, and its usage is discouraged. + Instead you should use the following approach: + + >>> routes = [Route(path, endpoint=...), ...] + >>> app = Starlette(routes=routes) + """ + warnings.warn( + "The `route` decorator is deprecated, and will be removed in version 1.0.0. " + "Refer to https://www.starlette.io/routing/ for the recommended approach.", + DeprecationWarning, + ) + + def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] + self.router.add_route( + path, + func, + methods=methods, + name=name, + include_in_schema=include_in_schema, + ) + return func + + return decorator + + def websocket_route(self, path: str, name: str | None = None) -> typing.Callable: # type: ignore[type-arg] + """ + We no longer document this decorator style API, and its usage is discouraged. + Instead you should use the following approach: + + >>> routes = [WebSocketRoute(path, endpoint=...), ...] + >>> app = Starlette(routes=routes) + """ + warnings.warn( + "The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. " + "Refer to https://www.starlette.io/routing/#websocket-routing for the recommended approach.", + DeprecationWarning, + ) + + def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] + self.router.add_websocket_route(path, func, name=name) + return func + + return decorator + + def middleware(self, middleware_type: str) -> typing.Callable: # type: ignore[type-arg] + """ + We no longer document this decorator style API, and its usage is discouraged. + Instead you should use the following approach: + + >>> middleware = [Middleware(...), ...] + >>> app = Starlette(middleware=middleware) + """ + warnings.warn( + "The `middleware` decorator is deprecated, and will be removed in version 1.0.0. " + "Refer to https://www.starlette.io/middleware/#using-middleware for recommended approach.", + DeprecationWarning, + ) + assert middleware_type == "http", 'Currently only middleware("http") is supported.' + + def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] + self.add_middleware(BaseHTTPMiddleware, dispatch=func) + return func + + return decorator diff --git a/.venv/lib/python3.12/site-packages/starlette/authentication.py b/.venv/lib/python3.12/site-packages/starlette/authentication.py new file mode 100644 index 00000000..4fd86641 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/authentication.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +import functools +import inspect +import sys +import typing +from urllib.parse import urlencode + +if sys.version_info >= (3, 10): # pragma: no cover + from typing import ParamSpec +else: # pragma: no cover + from typing_extensions import ParamSpec + +from starlette._utils import is_async_callable +from starlette.exceptions import HTTPException +from starlette.requests import HTTPConnection, Request +from starlette.responses import RedirectResponse +from starlette.websockets import WebSocket + +_P = ParamSpec("_P") + + +def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bool: + for scope in scopes: + if scope not in conn.auth.scopes: + return False + return True + + +def requires( + scopes: str | typing.Sequence[str], + status_code: int = 403, + redirect: str | None = None, +) -> typing.Callable[[typing.Callable[_P, typing.Any]], typing.Callable[_P, typing.Any]]: + scopes_list = [scopes] if isinstance(scopes, str) else list(scopes) + + def decorator( + func: typing.Callable[_P, typing.Any], + ) -> typing.Callable[_P, typing.Any]: + sig = inspect.signature(func) + for idx, parameter in enumerate(sig.parameters.values()): + if parameter.name == "request" or parameter.name == "websocket": + type_ = parameter.name + break + else: + raise Exception(f'No "request" or "websocket" argument on function "{func}"') + + if type_ == "websocket": + # Handle websocket functions. (Always async) + @functools.wraps(func) + async def websocket_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: + websocket = kwargs.get("websocket", args[idx] if idx < len(args) else None) + assert isinstance(websocket, WebSocket) + + if not has_required_scope(websocket, scopes_list): + await websocket.close() + else: + await func(*args, **kwargs) + + return websocket_wrapper + + elif is_async_callable(func): + # Handle async request/response functions. + @functools.wraps(func) + async def async_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> typing.Any: + request = kwargs.get("request", args[idx] if idx < len(args) else None) + assert isinstance(request, Request) + + if not has_required_scope(request, scopes_list): + if redirect is not None: + orig_request_qparam = urlencode({"next": str(request.url)}) + next_url = f"{request.url_for(redirect)}?{orig_request_qparam}" + return RedirectResponse(url=next_url, status_code=303) + raise HTTPException(status_code=status_code) + return await func(*args, **kwargs) + + return async_wrapper + + else: + # Handle sync request/response functions. + @functools.wraps(func) + def sync_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> typing.Any: + request = kwargs.get("request", args[idx] if idx < len(args) else None) + assert isinstance(request, Request) + + if not has_required_scope(request, scopes_list): + if redirect is not None: + orig_request_qparam = urlencode({"next": str(request.url)}) + next_url = f"{request.url_for(redirect)}?{orig_request_qparam}" + return RedirectResponse(url=next_url, status_code=303) + raise HTTPException(status_code=status_code) + return func(*args, **kwargs) + + return sync_wrapper + + return decorator + + +class AuthenticationError(Exception): + pass + + +class AuthenticationBackend: + async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None: + raise NotImplementedError() # pragma: no cover + + +class AuthCredentials: + def __init__(self, scopes: typing.Sequence[str] | None = None): + self.scopes = [] if scopes is None else list(scopes) + + +class BaseUser: + @property + def is_authenticated(self) -> bool: + raise NotImplementedError() # pragma: no cover + + @property + def display_name(self) -> str: + raise NotImplementedError() # pragma: no cover + + @property + def identity(self) -> str: + raise NotImplementedError() # pragma: no cover + + +class SimpleUser(BaseUser): + def __init__(self, username: str) -> None: + self.username = username + + @property + def is_authenticated(self) -> bool: + return True + + @property + def display_name(self) -> str: + return self.username + + +class UnauthenticatedUser(BaseUser): + @property + def is_authenticated(self) -> bool: + return False + + @property + def display_name(self) -> str: + return "" diff --git a/.venv/lib/python3.12/site-packages/starlette/background.py b/.venv/lib/python3.12/site-packages/starlette/background.py new file mode 100644 index 00000000..0430fc08 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/background.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +import sys +import typing + +if sys.version_info >= (3, 10): # pragma: no cover + from typing import ParamSpec +else: # pragma: no cover + from typing_extensions import ParamSpec + +from starlette._utils import is_async_callable +from starlette.concurrency import run_in_threadpool + +P = ParamSpec("P") + + +class BackgroundTask: + def __init__(self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs) -> None: + self.func = func + self.args = args + self.kwargs = kwargs + self.is_async = is_async_callable(func) + + async def __call__(self) -> None: + if self.is_async: + await self.func(*self.args, **self.kwargs) + else: + await run_in_threadpool(self.func, *self.args, **self.kwargs) + + +class BackgroundTasks(BackgroundTask): + def __init__(self, tasks: typing.Sequence[BackgroundTask] | None = None): + self.tasks = list(tasks) if tasks else [] + + def add_task(self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs) -> None: + task = BackgroundTask(func, *args, **kwargs) + self.tasks.append(task) + + async def __call__(self) -> None: + for task in self.tasks: + await task() diff --git a/.venv/lib/python3.12/site-packages/starlette/concurrency.py b/.venv/lib/python3.12/site-packages/starlette/concurrency.py new file mode 100644 index 00000000..494f3420 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/concurrency.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +import functools +import sys +import typing +import warnings + +import anyio.to_thread + +if sys.version_info >= (3, 10): # pragma: no cover + from typing import ParamSpec +else: # pragma: no cover + from typing_extensions import ParamSpec + +P = ParamSpec("P") +T = typing.TypeVar("T") + + +async def run_until_first_complete(*args: tuple[typing.Callable, dict]) -> None: # type: ignore[type-arg] + warnings.warn( + "run_until_first_complete is deprecated and will be removed in a future version.", + DeprecationWarning, + ) + + async with anyio.create_task_group() as task_group: + + async def run(func: typing.Callable[[], typing.Coroutine]) -> None: # type: ignore[type-arg] + await func() + task_group.cancel_scope.cancel() + + for func, kwargs in args: + task_group.start_soon(run, functools.partial(func, **kwargs)) + + +async def run_in_threadpool(func: typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: + func = functools.partial(func, *args, **kwargs) + return await anyio.to_thread.run_sync(func) + + +class _StopIteration(Exception): + pass + + +def _next(iterator: typing.Iterator[T]) -> T: + # We can't raise `StopIteration` from within the threadpool iterator + # and catch it outside that context, so we coerce them into a different + # exception type. + try: + return next(iterator) + except StopIteration: + raise _StopIteration + + +async def iterate_in_threadpool( + iterator: typing.Iterable[T], +) -> typing.AsyncIterator[T]: + as_iterator = iter(iterator) + while True: + try: + yield await anyio.to_thread.run_sync(_next, as_iterator) + except _StopIteration: + break diff --git a/.venv/lib/python3.12/site-packages/starlette/config.py b/.venv/lib/python3.12/site-packages/starlette/config.py new file mode 100644 index 00000000..ca15c564 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/config.py @@ -0,0 +1,138 @@ +from __future__ import annotations + +import os +import typing +import warnings +from pathlib import Path + + +class undefined: + pass + + +class EnvironError(Exception): + pass + + +class Environ(typing.MutableMapping[str, str]): + def __init__(self, environ: typing.MutableMapping[str, str] = os.environ): + self._environ = environ + self._has_been_read: set[str] = set() + + def __getitem__(self, key: str) -> str: + self._has_been_read.add(key) + return self._environ.__getitem__(key) + + def __setitem__(self, key: str, value: str) -> None: + if key in self._has_been_read: + raise EnvironError(f"Attempting to set environ['{key}'], but the value has already been read.") + self._environ.__setitem__(key, value) + + def __delitem__(self, key: str) -> None: + if key in self._has_been_read: + raise EnvironError(f"Attempting to delete environ['{key}'], but the value has already been read.") + self._environ.__delitem__(key) + + def __iter__(self) -> typing.Iterator[str]: + return iter(self._environ) + + def __len__(self) -> int: + return len(self._environ) + + +environ = Environ() + +T = typing.TypeVar("T") + + +class Config: + def __init__( + self, + env_file: str | Path | None = None, + environ: typing.Mapping[str, str] = environ, + env_prefix: str = "", + ) -> None: + self.environ = environ + self.env_prefix = env_prefix + self.file_values: dict[str, str] = {} + if env_file is not None: + if not os.path.isfile(env_file): + warnings.warn(f"Config file '{env_file}' not found.") + else: + self.file_values = self._read_file(env_file) + + @typing.overload + def __call__(self, key: str, *, default: None) -> str | None: ... + + @typing.overload + def __call__(self, key: str, cast: type[T], default: T = ...) -> T: ... + + @typing.overload + def __call__(self, key: str, cast: type[str] = ..., default: str = ...) -> str: ... + + @typing.overload + def __call__( + self, + key: str, + cast: typing.Callable[[typing.Any], T] = ..., + default: typing.Any = ..., + ) -> T: ... + + @typing.overload + def __call__(self, key: str, cast: type[str] = ..., default: T = ...) -> T | str: ... + + def __call__( + self, + key: str, + cast: typing.Callable[[typing.Any], typing.Any] | None = None, + default: typing.Any = undefined, + ) -> typing.Any: + return self.get(key, cast, default) + + def get( + self, + key: str, + cast: typing.Callable[[typing.Any], typing.Any] | None = None, + default: typing.Any = undefined, + ) -> typing.Any: + key = self.env_prefix + key + if key in self.environ: + value = self.environ[key] + return self._perform_cast(key, value, cast) + if key in self.file_values: + value = self.file_values[key] + return self._perform_cast(key, value, cast) + if default is not undefined: + return self._perform_cast(key, default, cast) + raise KeyError(f"Config '{key}' is missing, and has no default.") + + def _read_file(self, file_name: str | Path) -> dict[str, str]: + file_values: dict[str, str] = {} + with open(file_name) as input_file: + for line in input_file.readlines(): + line = line.strip() + if "=" in line and not line.startswith("#"): + key, value = line.split("=", 1) + key = key.strip() + value = value.strip().strip("\"'") + file_values[key] = value + return file_values + + def _perform_cast( + self, + key: str, + value: typing.Any, + cast: typing.Callable[[typing.Any], typing.Any] | None = None, + ) -> typing.Any: + if cast is None or value is None: + return value + elif cast is bool and isinstance(value, str): + mapping = {"true": True, "1": True, "false": False, "0": False} + value = value.lower() + if value not in mapping: + raise ValueError(f"Config '{key}' has value '{value}'. Not a valid bool.") + return mapping[value] + try: + return cast(value) + except (TypeError, ValueError): + raise ValueError(f"Config '{key}' has value '{value}'. Not a valid {cast.__name__}.") diff --git a/.venv/lib/python3.12/site-packages/starlette/convertors.py b/.venv/lib/python3.12/site-packages/starlette/convertors.py new file mode 100644 index 00000000..84df87a5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/convertors.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +import math +import typing +import uuid + +T = typing.TypeVar("T") + + +class Convertor(typing.Generic[T]): + regex: typing.ClassVar[str] = "" + + def convert(self, value: str) -> T: + raise NotImplementedError() # pragma: no cover + + def to_string(self, value: T) -> str: + raise NotImplementedError() # pragma: no cover + + +class StringConvertor(Convertor[str]): + regex = "[^/]+" + + def convert(self, value: str) -> str: + return value + + def to_string(self, value: str) -> str: + value = str(value) + assert "/" not in value, "May not contain path separators" + assert value, "Must not be empty" + return value + + +class PathConvertor(Convertor[str]): + regex = ".*" + + def convert(self, value: str) -> str: + return str(value) + + def to_string(self, value: str) -> str: + return str(value) + + +class IntegerConvertor(Convertor[int]): + regex = "[0-9]+" + + def convert(self, value: str) -> int: + return int(value) + + def to_string(self, value: int) -> str: + value = int(value) + assert value >= 0, "Negative integers are not supported" + return str(value) + + +class FloatConvertor(Convertor[float]): + regex = r"[0-9]+(\.[0-9]+)?" + + def convert(self, value: str) -> float: + return float(value) + + def to_string(self, value: float) -> str: + value = float(value) + assert value >= 0.0, "Negative floats are not supported" + assert not math.isnan(value), "NaN values are not supported" + assert not math.isinf(value), "Infinite values are not supported" + return ("%0.20f" % value).rstrip("0").rstrip(".") + + +class UUIDConvertor(Convertor[uuid.UUID]): + regex = "[0-9a-fA-F]{8}-?[0-9a-fA-F]{4}-?[0-9a-fA-F]{4}-?[0-9a-fA-F]{4}-?[0-9a-fA-F]{12}" + + def convert(self, value: str) -> uuid.UUID: + return uuid.UUID(value) + + def to_string(self, value: uuid.UUID) -> str: + return str(value) + + +CONVERTOR_TYPES: dict[str, Convertor[typing.Any]] = { + "str": StringConvertor(), + "path": PathConvertor(), + "int": IntegerConvertor(), + "float": FloatConvertor(), + "uuid": UUIDConvertor(), +} + + +def register_url_convertor(key: str, convertor: Convertor[typing.Any]) -> None: + CONVERTOR_TYPES[key] = convertor diff --git a/.venv/lib/python3.12/site-packages/starlette/datastructures.py b/.venv/lib/python3.12/site-packages/starlette/datastructures.py new file mode 100644 index 00000000..f5d74d25 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/datastructures.py @@ -0,0 +1,674 @@ +from __future__ import annotations + +import typing +from shlex import shlex +from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit + +from starlette.concurrency import run_in_threadpool +from starlette.types import Scope + + +class Address(typing.NamedTuple): + host: str + port: int + + +_KeyType = typing.TypeVar("_KeyType") +# Mapping keys are invariant but their values are covariant since +# you can only read them +# that is, you can't do `Mapping[str, Animal]()["fido"] = Dog()` +_CovariantValueType = typing.TypeVar("_CovariantValueType", covariant=True) + + +class URL: + def __init__( + self, + url: str = "", + scope: Scope | None = None, + **components: typing.Any, + ) -> None: + if scope is not None: + assert not url, 'Cannot set both "url" and "scope".' + assert not components, 'Cannot set both "scope" and "**components".' + scheme = scope.get("scheme", "http") + server = scope.get("server", None) + path = scope["path"] + query_string = scope.get("query_string", b"") + + host_header = None + for key, value in scope["headers"]: + if key == b"host": + host_header = value.decode("latin-1") + break + + if host_header is not None: + url = f"{scheme}://{host_header}{path}" + elif server is None: + url = path + else: + host, port = server + default_port = {"http": 80, "https": 443, "ws": 80, "wss": 443}[scheme] + if port == default_port: + url = f"{scheme}://{host}{path}" + else: + url = f"{scheme}://{host}:{port}{path}" + + if query_string: + url += "?" + query_string.decode() + elif components: + assert not url, 'Cannot set both "url" and "**components".' + url = URL("").replace(**components).components.geturl() + + self._url = url + + @property + def components(self) -> SplitResult: + if not hasattr(self, "_components"): + self._components = urlsplit(self._url) + return self._components + + @property + def scheme(self) -> str: + return self.components.scheme + + @property + def netloc(self) -> str: + return self.components.netloc + + @property + def path(self) -> str: + return self.components.path + + @property + def query(self) -> str: + return self.components.query + + @property + def fragment(self) -> str: + return self.components.fragment + + @property + def username(self) -> None | str: + return self.components.username + + @property + def password(self) -> None | str: + return self.components.password + + @property + def hostname(self) -> None | str: + return self.components.hostname + + @property + def port(self) -> int | None: + return self.components.port + + @property + def is_secure(self) -> bool: + return self.scheme in ("https", "wss") + + def replace(self, **kwargs: typing.Any) -> URL: + if "username" in kwargs or "password" in kwargs or "hostname" in kwargs or "port" in kwargs: + hostname = kwargs.pop("hostname", None) + port = kwargs.pop("port", self.port) + username = kwargs.pop("username", self.username) + password = kwargs.pop("password", self.password) + + if hostname is None: + netloc = self.netloc + _, _, hostname = netloc.rpartition("@") + + if hostname[-1] != "]": + hostname = hostname.rsplit(":", 1)[0] + + netloc = hostname + if port is not None: + netloc += f":{port}" + if username is not None: + userpass = username + if password is not None: + userpass += f":{password}" + netloc = f"{userpass}@{netloc}" + + kwargs["netloc"] = netloc + + components = self.components._replace(**kwargs) + return self.__class__(components.geturl()) + + def include_query_params(self, **kwargs: typing.Any) -> URL: + params = MultiDict(parse_qsl(self.query, keep_blank_values=True)) + params.update({str(key): str(value) for key, value in kwargs.items()}) + query = urlencode(params.multi_items()) + return self.replace(query=query) + + def replace_query_params(self, **kwargs: typing.Any) -> URL: + query = urlencode([(str(key), str(value)) for key, value in kwargs.items()]) + return self.replace(query=query) + + def remove_query_params(self, keys: str | typing.Sequence[str]) -> URL: + if isinstance(keys, str): + keys = [keys] + params = MultiDict(parse_qsl(self.query, keep_blank_values=True)) + for key in keys: + params.pop(key, None) + query = urlencode(params.multi_items()) + return self.replace(query=query) + + def __eq__(self, other: typing.Any) -> bool: + return str(self) == str(other) + + def __str__(self) -> str: + return self._url + + def __repr__(self) -> str: + url = str(self) + if self.password: + url = str(self.replace(password="********")) + return f"{self.__class__.__name__}({repr(url)})" + + +class URLPath(str): + """ + A URL path string that may also hold an associated protocol and/or host. + Used by the routing to return `url_path_for` matches. + """ + + def __new__(cls, path: str, protocol: str = "", host: str = "") -> URLPath: + assert protocol in ("http", "websocket", "") + return str.__new__(cls, path) + + def __init__(self, path: str, protocol: str = "", host: str = "") -> None: + self.protocol = protocol + self.host = host + + def make_absolute_url(self, base_url: str | URL) -> URL: + if isinstance(base_url, str): + base_url = URL(base_url) + if self.protocol: + scheme = { + "http": {True: "https", False: "http"}, + "websocket": {True: "wss", False: "ws"}, + }[self.protocol][base_url.is_secure] + else: + scheme = base_url.scheme + + netloc = self.host or base_url.netloc + path = base_url.path.rstrip("/") + str(self) + return URL(scheme=scheme, netloc=netloc, path=path) + + +class Secret: + """ + Holds a string value that should not be revealed in tracebacks etc. + You should cast the value to `str` at the point it is required. + """ + + def __init__(self, value: str): + self._value = value + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + return f"{class_name}('**********')" + + def __str__(self) -> str: + return self._value + + def __bool__(self) -> bool: + return bool(self._value) + + +class CommaSeparatedStrings(typing.Sequence[str]): + def __init__(self, value: str | typing.Sequence[str]): + if isinstance(value, str): + splitter = shlex(value, posix=True) + splitter.whitespace = "," + splitter.whitespace_split = True + self._items = [item.strip() for item in splitter] + else: + self._items = list(value) + + def __len__(self) -> int: + return len(self._items) + + def __getitem__(self, index: int | slice) -> typing.Any: + return self._items[index] + + def __iter__(self) -> typing.Iterator[str]: + return iter(self._items) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + items = [item for item in self] + return f"{class_name}({items!r})" + + def __str__(self) -> str: + return ", ".join(repr(item) for item in self) + + +class ImmutableMultiDict(typing.Mapping[_KeyType, _CovariantValueType]): + _dict: dict[_KeyType, _CovariantValueType] + + def __init__( + self, + *args: ImmutableMultiDict[_KeyType, _CovariantValueType] + | typing.Mapping[_KeyType, _CovariantValueType] + | typing.Iterable[tuple[_KeyType, _CovariantValueType]], + **kwargs: typing.Any, + ) -> None: + assert len(args) < 2, "Too many arguments." + + value: typing.Any = args[0] if args else [] + if kwargs: + value = ImmutableMultiDict(value).multi_items() + ImmutableMultiDict(kwargs).multi_items() + + if not value: + _items: list[tuple[typing.Any, typing.Any]] = [] + elif hasattr(value, "multi_items"): + value = typing.cast(ImmutableMultiDict[_KeyType, _CovariantValueType], value) + _items = list(value.multi_items()) + elif hasattr(value, "items"): + value = typing.cast(typing.Mapping[_KeyType, _CovariantValueType], value) + _items = list(value.items()) + else: + value = typing.cast("list[tuple[typing.Any, typing.Any]]", value) + _items = list(value) + + self._dict = {k: v for k, v in _items} + self._list = _items + + def getlist(self, key: typing.Any) -> list[_CovariantValueType]: + return [item_value for item_key, item_value in self._list if item_key == key] + + def keys(self) -> typing.KeysView[_KeyType]: + return self._dict.keys() + + def values(self) -> typing.ValuesView[_CovariantValueType]: + return self._dict.values() + + def items(self) -> typing.ItemsView[_KeyType, _CovariantValueType]: + return self._dict.items() + + def multi_items(self) -> list[tuple[_KeyType, _CovariantValueType]]: + return list(self._list) + + def __getitem__(self, key: _KeyType) -> _CovariantValueType: + return self._dict[key] + + def __contains__(self, key: typing.Any) -> bool: + return key in self._dict + + def __iter__(self) -> typing.Iterator[_KeyType]: + return iter(self.keys()) + + def __len__(self) -> int: + return len(self._dict) + + def __eq__(self, other: typing.Any) -> bool: + if not isinstance(other, self.__class__): + return False + return sorted(self._list) == sorted(other._list) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + items = self.multi_items() + return f"{class_name}({items!r})" + + +class MultiDict(ImmutableMultiDict[typing.Any, typing.Any]): + def __setitem__(self, key: typing.Any, value: typing.Any) -> None: + self.setlist(key, [value]) + + def __delitem__(self, key: typing.Any) -> None: + self._list = [(k, v) for k, v in self._list if k != key] + del self._dict[key] + + def pop(self, key: typing.Any, default: typing.Any = None) -> typing.Any: + self._list = [(k, v) for k, v in self._list if k != key] + return self._dict.pop(key, default) + + def popitem(self) -> tuple[typing.Any, typing.Any]: + key, value = self._dict.popitem() + self._list = [(k, v) for k, v in self._list if k != key] + return key, value + + def poplist(self, key: typing.Any) -> list[typing.Any]: + values = [v for k, v in self._list if k == key] + self.pop(key) + return values + + def clear(self) -> None: + self._dict.clear() + self._list.clear() + + def setdefault(self, key: typing.Any, default: typing.Any = None) -> typing.Any: + if key not in self: + self._dict[key] = default + self._list.append((key, default)) + + return self[key] + + def setlist(self, key: typing.Any, values: list[typing.Any]) -> None: + if not values: + self.pop(key, None) + else: + existing_items = [(k, v) for (k, v) in self._list if k != key] + self._list = existing_items + [(key, value) for value in values] + self._dict[key] = values[-1] + + def append(self, key: typing.Any, value: typing.Any) -> None: + self._list.append((key, value)) + self._dict[key] = value + + def update( + self, + *args: MultiDict | typing.Mapping[typing.Any, typing.Any] | list[tuple[typing.Any, typing.Any]], + **kwargs: typing.Any, + ) -> None: + value = MultiDict(*args, **kwargs) + existing_items = [(k, v) for (k, v) in self._list if k not in value.keys()] + self._list = existing_items + value.multi_items() + self._dict.update(value) + + +class QueryParams(ImmutableMultiDict[str, str]): + """ + An immutable multidict. + """ + + def __init__( + self, + *args: ImmutableMultiDict[typing.Any, typing.Any] + | typing.Mapping[typing.Any, typing.Any] + | list[tuple[typing.Any, typing.Any]] + | str + | bytes, + **kwargs: typing.Any, + ) -> None: + assert len(args) < 2, "Too many arguments." + + value = args[0] if args else [] + + if isinstance(value, str): + super().__init__(parse_qsl(value, keep_blank_values=True), **kwargs) + elif isinstance(value, bytes): + super().__init__(parse_qsl(value.decode("latin-1"), keep_blank_values=True), **kwargs) + else: + super().__init__(*args, **kwargs) # type: ignore[arg-type] + self._list = [(str(k), str(v)) for k, v in self._list] + self._dict = {str(k): str(v) for k, v in self._dict.items()} + + def __str__(self) -> str: + return urlencode(self._list) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + query_string = str(self) + return f"{class_name}({query_string!r})" + + +class UploadFile: + """ + An uploaded file included as part of the request data. + """ + + def __init__( + self, + file: typing.BinaryIO, + *, + size: int | None = None, + filename: str | None = None, + headers: Headers | None = None, + ) -> None: + self.filename = filename + self.file = file + self.size = size + self.headers = headers or Headers() + + @property + def content_type(self) -> str | None: + return self.headers.get("content-type", None) + + @property + def _in_memory(self) -> bool: + # check for SpooledTemporaryFile._rolled + rolled_to_disk = getattr(self.file, "_rolled", True) + return not rolled_to_disk + + async def write(self, data: bytes) -> None: + if self.size is not None: + self.size += len(data) + + if self._in_memory: + self.file.write(data) + else: + await run_in_threadpool(self.file.write, data) + + async def read(self, size: int = -1) -> bytes: + if self._in_memory: + return self.file.read(size) + return await run_in_threadpool(self.file.read, size) + + async def seek(self, offset: int) -> None: + if self._in_memory: + self.file.seek(offset) + else: + await run_in_threadpool(self.file.seek, offset) + + async def close(self) -> None: + if self._in_memory: + self.file.close() + else: + await run_in_threadpool(self.file.close) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(filename={self.filename!r}, size={self.size!r}, headers={self.headers!r})" + + +class FormData(ImmutableMultiDict[str, typing.Union[UploadFile, str]]): + """ + An immutable multidict, containing both file uploads and text input. + """ + + def __init__( + self, + *args: FormData | typing.Mapping[str, str | UploadFile] | list[tuple[str, str | UploadFile]], + **kwargs: str | UploadFile, + ) -> None: + super().__init__(*args, **kwargs) + + async def close(self) -> None: + for key, value in self.multi_items(): + if isinstance(value, UploadFile): + await value.close() + + +class Headers(typing.Mapping[str, str]): + """ + An immutable, case-insensitive multidict. + """ + + def __init__( + self, + headers: typing.Mapping[str, str] | None = None, + raw: list[tuple[bytes, bytes]] | None = None, + scope: typing.MutableMapping[str, typing.Any] | None = None, + ) -> None: + self._list: list[tuple[bytes, bytes]] = [] + if headers is not None: + assert raw is None, 'Cannot set both "headers" and "raw".' + assert scope is None, 'Cannot set both "headers" and "scope".' + self._list = [(key.lower().encode("latin-1"), value.encode("latin-1")) for key, value in headers.items()] + elif raw is not None: + assert scope is None, 'Cannot set both "raw" and "scope".' + self._list = raw + elif scope is not None: + # scope["headers"] isn't necessarily a list + # it might be a tuple or other iterable + self._list = scope["headers"] = list(scope["headers"]) + + @property + def raw(self) -> list[tuple[bytes, bytes]]: + return list(self._list) + + def keys(self) -> list[str]: # type: ignore[override] + return [key.decode("latin-1") for key, value in self._list] + + def values(self) -> list[str]: # type: ignore[override] + return [value.decode("latin-1") for key, value in self._list] + + def items(self) -> list[tuple[str, str]]: # type: ignore[override] + return [(key.decode("latin-1"), value.decode("latin-1")) for key, value in self._list] + + def getlist(self, key: str) -> list[str]: + get_header_key = key.lower().encode("latin-1") + return [item_value.decode("latin-1") for item_key, item_value in self._list if item_key == get_header_key] + + def mutablecopy(self) -> MutableHeaders: + return MutableHeaders(raw=self._list[:]) + + def __getitem__(self, key: str) -> str: + get_header_key = key.lower().encode("latin-1") + for header_key, header_value in self._list: + if header_key == get_header_key: + return header_value.decode("latin-1") + raise KeyError(key) + + def __contains__(self, key: typing.Any) -> bool: + get_header_key = key.lower().encode("latin-1") + for header_key, header_value in self._list: + if header_key == get_header_key: + return True + return False + + def __iter__(self) -> typing.Iterator[typing.Any]: + return iter(self.keys()) + + def __len__(self) -> int: + return len(self._list) + + def __eq__(self, other: typing.Any) -> bool: + if not isinstance(other, Headers): + return False + return sorted(self._list) == sorted(other._list) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + as_dict = dict(self.items()) + if len(as_dict) == len(self): + return f"{class_name}({as_dict!r})" + return f"{class_name}(raw={self.raw!r})" + + +class MutableHeaders(Headers): + def __setitem__(self, key: str, value: str) -> None: + """ + Set the header `key` to `value`, removing any duplicate entries. + Retains insertion order. + """ + set_key = key.lower().encode("latin-1") + set_value = value.encode("latin-1") + + found_indexes: list[int] = [] + for idx, (item_key, item_value) in enumerate(self._list): + if item_key == set_key: + found_indexes.append(idx) + + for idx in reversed(found_indexes[1:]): + del self._list[idx] + + if found_indexes: + idx = found_indexes[0] + self._list[idx] = (set_key, set_value) + else: + self._list.append((set_key, set_value)) + + def __delitem__(self, key: str) -> None: + """ + Remove the header `key`. + """ + del_key = key.lower().encode("latin-1") + + pop_indexes: list[int] = [] + for idx, (item_key, item_value) in enumerate(self._list): + if item_key == del_key: + pop_indexes.append(idx) + + for idx in reversed(pop_indexes): + del self._list[idx] + + def __ior__(self, other: typing.Mapping[str, str]) -> MutableHeaders: + if not isinstance(other, typing.Mapping): + raise TypeError(f"Expected a mapping but got {other.__class__.__name__}") + self.update(other) + return self + + def __or__(self, other: typing.Mapping[str, str]) -> MutableHeaders: + if not isinstance(other, typing.Mapping): + raise TypeError(f"Expected a mapping but got {other.__class__.__name__}") + new = self.mutablecopy() + new.update(other) + return new + + @property + def raw(self) -> list[tuple[bytes, bytes]]: + return self._list + + def setdefault(self, key: str, value: str) -> str: + """ + If the header `key` does not exist, then set it to `value`. + Returns the header value. + """ + set_key = key.lower().encode("latin-1") + set_value = value.encode("latin-1") + + for idx, (item_key, item_value) in enumerate(self._list): + if item_key == set_key: + return item_value.decode("latin-1") + self._list.append((set_key, set_value)) + return value + + def update(self, other: typing.Mapping[str, str]) -> None: + for key, val in other.items(): + self[key] = val + + def append(self, key: str, value: str) -> None: + """ + Append a header, preserving any duplicate entries. + """ + append_key = key.lower().encode("latin-1") + append_value = value.encode("latin-1") + self._list.append((append_key, append_value)) + + def add_vary_header(self, vary: str) -> None: + existing = self.get("vary") + if existing is not None: + vary = ", ".join([existing, vary]) + self["vary"] = vary + + +class State: + """ + An object that can be used to store arbitrary state. + + Used for `request.state` and `app.state`. + """ + + _state: dict[str, typing.Any] + + def __init__(self, state: dict[str, typing.Any] | None = None): + if state is None: + state = {} + super().__setattr__("_state", state) + + def __setattr__(self, key: typing.Any, value: typing.Any) -> None: + self._state[key] = value + + def __getattr__(self, key: typing.Any) -> typing.Any: + try: + return self._state[key] + except KeyError: + message = "'{}' object has no attribute '{}'" + raise AttributeError(message.format(self.__class__.__name__, key)) + + def __delattr__(self, key: typing.Any) -> None: + del self._state[key] diff --git a/.venv/lib/python3.12/site-packages/starlette/endpoints.py b/.venv/lib/python3.12/site-packages/starlette/endpoints.py new file mode 100644 index 00000000..10769026 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/endpoints.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +import json +import typing + +from starlette import status +from starlette._utils import is_async_callable +from starlette.concurrency import run_in_threadpool +from starlette.exceptions import HTTPException +from starlette.requests import Request +from starlette.responses import PlainTextResponse, Response +from starlette.types import Message, Receive, Scope, Send +from starlette.websockets import WebSocket + + +class HTTPEndpoint: + def __init__(self, scope: Scope, receive: Receive, send: Send) -> None: + assert scope["type"] == "http" + self.scope = scope + self.receive = receive + self.send = send + self._allowed_methods = [ + method + for method in ("GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "OPTIONS") + if getattr(self, method.lower(), None) is not None + ] + + def __await__(self) -> typing.Generator[typing.Any, None, None]: + return self.dispatch().__await__() + + async def dispatch(self) -> None: + request = Request(self.scope, receive=self.receive) + handler_name = "get" if request.method == "HEAD" and not hasattr(self, "head") else request.method.lower() + + handler: typing.Callable[[Request], typing.Any] = getattr(self, handler_name, self.method_not_allowed) + is_async = is_async_callable(handler) + if is_async: + response = await handler(request) + else: + response = await run_in_threadpool(handler, request) + await response(self.scope, self.receive, self.send) + + async def method_not_allowed(self, request: Request) -> Response: + # If we're running inside a starlette application then raise an + # exception, so that the configurable exception handler can deal with + # returning the response. For plain ASGI apps, just return the response. + headers = {"Allow": ", ".join(self._allowed_methods)} + if "app" in self.scope: + raise HTTPException(status_code=405, headers=headers) + return PlainTextResponse("Method Not Allowed", status_code=405, headers=headers) + + +class WebSocketEndpoint: + encoding: str | None = None # May be "text", "bytes", or "json". + + def __init__(self, scope: Scope, receive: Receive, send: Send) -> None: + assert scope["type"] == "websocket" + self.scope = scope + self.receive = receive + self.send = send + + def __await__(self) -> typing.Generator[typing.Any, None, None]: + return self.dispatch().__await__() + + async def dispatch(self) -> None: + websocket = WebSocket(self.scope, receive=self.receive, send=self.send) + await self.on_connect(websocket) + + close_code = status.WS_1000_NORMAL_CLOSURE + + try: + while True: + message = await websocket.receive() + if message["type"] == "websocket.receive": + data = await self.decode(websocket, message) + await self.on_receive(websocket, data) + elif message["type"] == "websocket.disconnect": # pragma: no branch + close_code = int(message.get("code") or status.WS_1000_NORMAL_CLOSURE) + break + except Exception as exc: + close_code = status.WS_1011_INTERNAL_ERROR + raise exc + finally: + await self.on_disconnect(websocket, close_code) + + async def decode(self, websocket: WebSocket, message: Message) -> typing.Any: + if self.encoding == "text": + if "text" not in message: + await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA) + raise RuntimeError("Expected text websocket messages, but got bytes") + return message["text"] + + elif self.encoding == "bytes": + if "bytes" not in message: + await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA) + raise RuntimeError("Expected bytes websocket messages, but got text") + return message["bytes"] + + elif self.encoding == "json": + if message.get("text") is not None: + text = message["text"] + else: + text = message["bytes"].decode("utf-8") + + try: + return json.loads(text) + except json.decoder.JSONDecodeError: + await websocket.close(code=status.WS_1003_UNSUPPORTED_DATA) + raise RuntimeError("Malformed JSON data received.") + + assert self.encoding is None, f"Unsupported 'encoding' attribute {self.encoding}" + return message["text"] if message.get("text") else message["bytes"] + + async def on_connect(self, websocket: WebSocket) -> None: + """Override to handle an incoming websocket connection""" + await websocket.accept() + + async def on_receive(self, websocket: WebSocket, data: typing.Any) -> None: + """Override to handle an incoming websocket message""" + + async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None: + """Override to handle a disconnecting websocket""" diff --git a/.venv/lib/python3.12/site-packages/starlette/exceptions.py b/.venv/lib/python3.12/site-packages/starlette/exceptions.py new file mode 100644 index 00000000..9ad3527b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/exceptions.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +import http +from collections.abc import Mapping + + +class HTTPException(Exception): + def __init__(self, status_code: int, detail: str | None = None, headers: Mapping[str, str] | None = None) -> None: + if detail is None: + detail = http.HTTPStatus(status_code).phrase + self.status_code = status_code + self.detail = detail + self.headers = headers + + def __str__(self) -> str: + return f"{self.status_code}: {self.detail}" + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + return f"{class_name}(status_code={self.status_code!r}, detail={self.detail!r})" + + +class WebSocketException(Exception): + def __init__(self, code: int, reason: str | None = None) -> None: + self.code = code + self.reason = reason or "" + + def __str__(self) -> str: + return f"{self.code}: {self.reason}" + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + return f"{class_name}(code={self.code!r}, reason={self.reason!r})" diff --git a/.venv/lib/python3.12/site-packages/starlette/formparsers.py b/.venv/lib/python3.12/site-packages/starlette/formparsers.py new file mode 100644 index 00000000..4551d688 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/formparsers.py @@ -0,0 +1,275 @@ +from __future__ import annotations + +import typing +from dataclasses import dataclass, field +from enum import Enum +from tempfile import SpooledTemporaryFile +from urllib.parse import unquote_plus + +from starlette.datastructures import FormData, Headers, UploadFile + +if typing.TYPE_CHECKING: + import python_multipart as multipart + from python_multipart.multipart import MultipartCallbacks, QuerystringCallbacks, parse_options_header +else: + try: + try: + import python_multipart as multipart + from python_multipart.multipart import parse_options_header + except ModuleNotFoundError: # pragma: no cover + import multipart + from multipart.multipart import parse_options_header + except ModuleNotFoundError: # pragma: no cover + multipart = None + parse_options_header = None + + +class FormMessage(Enum): + FIELD_START = 1 + FIELD_NAME = 2 + FIELD_DATA = 3 + FIELD_END = 4 + END = 5 + + +@dataclass +class MultipartPart: + content_disposition: bytes | None = None + field_name: str = "" + data: bytearray = field(default_factory=bytearray) + file: UploadFile | None = None + item_headers: list[tuple[bytes, bytes]] = field(default_factory=list) + + +def _user_safe_decode(src: bytes | bytearray, codec: str) -> str: + try: + return src.decode(codec) + except (UnicodeDecodeError, LookupError): + return src.decode("latin-1") + + +class MultiPartException(Exception): + def __init__(self, message: str) -> None: + self.message = message + + +class FormParser: + def __init__(self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]) -> None: + assert multipart is not None, "The `python-multipart` library must be installed to use form parsing." + self.headers = headers + self.stream = stream + self.messages: list[tuple[FormMessage, bytes]] = [] + + def on_field_start(self) -> None: + message = (FormMessage.FIELD_START, b"") + self.messages.append(message) + + def on_field_name(self, data: bytes, start: int, end: int) -> None: + message = (FormMessage.FIELD_NAME, data[start:end]) + self.messages.append(message) + + def on_field_data(self, data: bytes, start: int, end: int) -> None: + message = (FormMessage.FIELD_DATA, data[start:end]) + self.messages.append(message) + + def on_field_end(self) -> None: + message = (FormMessage.FIELD_END, b"") + self.messages.append(message) + + def on_end(self) -> None: + message = (FormMessage.END, b"") + self.messages.append(message) + + async def parse(self) -> FormData: + # Callbacks dictionary. + callbacks: QuerystringCallbacks = { + "on_field_start": self.on_field_start, + "on_field_name": self.on_field_name, + "on_field_data": self.on_field_data, + "on_field_end": self.on_field_end, + "on_end": self.on_end, + } + + # Create the parser. + parser = multipart.QuerystringParser(callbacks) + field_name = b"" + field_value = b"" + + items: list[tuple[str, str | UploadFile]] = [] + + # Feed the parser with data from the request. + async for chunk in self.stream: + if chunk: + parser.write(chunk) + else: + parser.finalize() + messages = list(self.messages) + self.messages.clear() + for message_type, message_bytes in messages: + if message_type == FormMessage.FIELD_START: + field_name = b"" + field_value = b"" + elif message_type == FormMessage.FIELD_NAME: + field_name += message_bytes + elif message_type == FormMessage.FIELD_DATA: + field_value += message_bytes + elif message_type == FormMessage.FIELD_END: + name = unquote_plus(field_name.decode("latin-1")) + value = unquote_plus(field_value.decode("latin-1")) + items.append((name, value)) + + return FormData(items) + + +class MultiPartParser: + spool_max_size = 1024 * 1024 # 1MB + """The maximum size of the spooled temporary file used to store file data.""" + max_part_size = 1024 * 1024 # 1MB + """The maximum size of a part in the multipart request.""" + + def __init__( + self, + headers: Headers, + stream: typing.AsyncGenerator[bytes, None], + *, + max_files: int | float = 1000, + max_fields: int | float = 1000, + max_part_size: int = 1024 * 1024, # 1MB + ) -> None: + assert multipart is not None, "The `python-multipart` library must be installed to use form parsing." + self.headers = headers + self.stream = stream + self.max_files = max_files + self.max_fields = max_fields + self.items: list[tuple[str, str | UploadFile]] = [] + self._current_files = 0 + self._current_fields = 0 + self._current_partial_header_name: bytes = b"" + self._current_partial_header_value: bytes = b"" + self._current_part = MultipartPart() + self._charset = "" + self._file_parts_to_write: list[tuple[MultipartPart, bytes]] = [] + self._file_parts_to_finish: list[MultipartPart] = [] + self._files_to_close_on_error: list[SpooledTemporaryFile[bytes]] = [] + self.max_part_size = max_part_size + + def on_part_begin(self) -> None: + self._current_part = MultipartPart() + + def on_part_data(self, data: bytes, start: int, end: int) -> None: + message_bytes = data[start:end] + if self._current_part.file is None: + if len(self._current_part.data) + len(message_bytes) > self.max_part_size: + raise MultiPartException(f"Part exceeded maximum size of {int(self.max_part_size / 1024)}KB.") + self._current_part.data.extend(message_bytes) + else: + self._file_parts_to_write.append((self._current_part, message_bytes)) + + def on_part_end(self) -> None: + if self._current_part.file is None: + self.items.append( + ( + self._current_part.field_name, + _user_safe_decode(self._current_part.data, self._charset), + ) + ) + else: + self._file_parts_to_finish.append(self._current_part) + # The file can be added to the items right now even though it's not + # finished yet, because it will be finished in the `parse()` method, before + # self.items is used in the return value. + self.items.append((self._current_part.field_name, self._current_part.file)) + + def on_header_field(self, data: bytes, start: int, end: int) -> None: + self._current_partial_header_name += data[start:end] + + def on_header_value(self, data: bytes, start: int, end: int) -> None: + self._current_partial_header_value += data[start:end] + + def on_header_end(self) -> None: + field = self._current_partial_header_name.lower() + if field == b"content-disposition": + self._current_part.content_disposition = self._current_partial_header_value + self._current_part.item_headers.append((field, self._current_partial_header_value)) + self._current_partial_header_name = b"" + self._current_partial_header_value = b"" + + def on_headers_finished(self) -> None: + disposition, options = parse_options_header(self._current_part.content_disposition) + try: + self._current_part.field_name = _user_safe_decode(options[b"name"], self._charset) + except KeyError: + raise MultiPartException('The Content-Disposition header field "name" must be provided.') + if b"filename" in options: + self._current_files += 1 + if self._current_files > self.max_files: + raise MultiPartException(f"Too many files. Maximum number of files is {self.max_files}.") + filename = _user_safe_decode(options[b"filename"], self._charset) + tempfile = SpooledTemporaryFile(max_size=self.spool_max_size) + self._files_to_close_on_error.append(tempfile) + self._current_part.file = UploadFile( + file=tempfile, # type: ignore[arg-type] + size=0, + filename=filename, + headers=Headers(raw=self._current_part.item_headers), + ) + else: + self._current_fields += 1 + if self._current_fields > self.max_fields: + raise MultiPartException(f"Too many fields. Maximum number of fields is {self.max_fields}.") + self._current_part.file = None + + def on_end(self) -> None: + pass + + async def parse(self) -> FormData: + # Parse the Content-Type header to get the multipart boundary. + _, params = parse_options_header(self.headers["Content-Type"]) + charset = params.get(b"charset", "utf-8") + if isinstance(charset, bytes): + charset = charset.decode("latin-1") + self._charset = charset + try: + boundary = params[b"boundary"] + except KeyError: + raise MultiPartException("Missing boundary in multipart.") + + # Callbacks dictionary. + callbacks: MultipartCallbacks = { + "on_part_begin": self.on_part_begin, + "on_part_data": self.on_part_data, + "on_part_end": self.on_part_end, + "on_header_field": self.on_header_field, + "on_header_value": self.on_header_value, + "on_header_end": self.on_header_end, + "on_headers_finished": self.on_headers_finished, + "on_end": self.on_end, + } + + # Create the parser. + parser = multipart.MultipartParser(boundary, callbacks) + try: + # Feed the parser with data from the request. + async for chunk in self.stream: + parser.write(chunk) + # Write file data, it needs to use await with the UploadFile methods + # that call the corresponding file methods *in a threadpool*, + # otherwise, if they were called directly in the callback methods above + # (regular, non-async functions), that would block the event loop in + # the main thread. + for part, data in self._file_parts_to_write: + assert part.file # for type checkers + await part.file.write(data) + for part in self._file_parts_to_finish: + assert part.file # for type checkers + await part.file.seek(0) + self._file_parts_to_write.clear() + self._file_parts_to_finish.clear() + except MultiPartException as exc: + # Close all the files if there was an error. + for file in self._files_to_close_on_error: + file.close() + raise exc + + parser.finalize() + return FormData(self.items) diff --git a/.venv/lib/python3.12/site-packages/starlette/middleware/__init__.py b/.venv/lib/python3.12/site-packages/starlette/middleware/__init__.py new file mode 100644 index 00000000..b99538a2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/middleware/__init__.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +import sys +from collections.abc import Iterator +from typing import Any, Protocol + +if sys.version_info >= (3, 10): # pragma: no cover + from typing import ParamSpec +else: # pragma: no cover + from typing_extensions import ParamSpec + +from starlette.types import ASGIApp + +P = ParamSpec("P") + + +class _MiddlewareFactory(Protocol[P]): + def __call__(self, app: ASGIApp, /, *args: P.args, **kwargs: P.kwargs) -> ASGIApp: ... # pragma: no cover + + +class Middleware: + def __init__( + self, + cls: _MiddlewareFactory[P], + *args: P.args, + **kwargs: P.kwargs, + ) -> None: + self.cls = cls + self.args = args + self.kwargs = kwargs + + def __iter__(self) -> Iterator[Any]: + as_tuple = (self.cls, self.args, self.kwargs) + return iter(as_tuple) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + args_strings = [f"{value!r}" for value in self.args] + option_strings = [f"{key}={value!r}" for key, value in self.kwargs.items()] + name = getattr(self.cls, "__name__", "") + args_repr = ", ".join([name] + args_strings + option_strings) + return f"{class_name}({args_repr})" diff --git a/.venv/lib/python3.12/site-packages/starlette/middleware/authentication.py b/.venv/lib/python3.12/site-packages/starlette/middleware/authentication.py new file mode 100644 index 00000000..8555ee07 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/middleware/authentication.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +import typing + +from starlette.authentication import ( + AuthCredentials, + AuthenticationBackend, + AuthenticationError, + UnauthenticatedUser, +) +from starlette.requests import HTTPConnection +from starlette.responses import PlainTextResponse, Response +from starlette.types import ASGIApp, Receive, Scope, Send + + +class AuthenticationMiddleware: + def __init__( + self, + app: ASGIApp, + backend: AuthenticationBackend, + on_error: typing.Callable[[HTTPConnection, AuthenticationError], Response] | None = None, + ) -> None: + self.app = app + self.backend = backend + self.on_error: typing.Callable[[HTTPConnection, AuthenticationError], Response] = ( + on_error if on_error is not None else self.default_on_error + ) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] not in ["http", "websocket"]: + await self.app(scope, receive, send) + return + + conn = HTTPConnection(scope) + try: + auth_result = await self.backend.authenticate(conn) + except AuthenticationError as exc: + response = self.on_error(conn, exc) + if scope["type"] == "websocket": + await send({"type": "websocket.close", "code": 1000}) + else: + await response(scope, receive, send) + return + + if auth_result is None: + auth_result = AuthCredentials(), UnauthenticatedUser() + scope["auth"], scope["user"] = auth_result + await self.app(scope, receive, send) + + @staticmethod + def default_on_error(conn: HTTPConnection, exc: Exception) -> Response: + return PlainTextResponse(str(exc), status_code=400) diff --git a/.venv/lib/python3.12/site-packages/starlette/middleware/base.py b/.venv/lib/python3.12/site-packages/starlette/middleware/base.py new file mode 100644 index 00000000..2a59337e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/middleware/base.py @@ -0,0 +1,220 @@ +from __future__ import annotations + +import typing + +import anyio + +from starlette._utils import collapse_excgroups +from starlette.requests import ClientDisconnect, Request +from starlette.responses import AsyncContentStream, Response +from starlette.types import ASGIApp, Message, Receive, Scope, Send + +RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]] +DispatchFunction = typing.Callable[[Request, RequestResponseEndpoint], typing.Awaitable[Response]] +T = typing.TypeVar("T") + + +class _CachedRequest(Request): + """ + If the user calls Request.body() from their dispatch function + we cache the entire request body in memory and pass that to downstream middlewares, + but if they call Request.stream() then all we do is send an + empty body so that downstream things don't hang forever. + """ + + def __init__(self, scope: Scope, receive: Receive): + super().__init__(scope, receive) + self._wrapped_rcv_disconnected = False + self._wrapped_rcv_consumed = False + self._wrapped_rc_stream = self.stream() + + async def wrapped_receive(self) -> Message: + # wrapped_rcv state 1: disconnected + if self._wrapped_rcv_disconnected: + # we've already sent a disconnect to the downstream app + # we don't need to wait to get another one + # (although most ASGI servers will just keep sending it) + return {"type": "http.disconnect"} + # wrapped_rcv state 1: consumed but not yet disconnected + if self._wrapped_rcv_consumed: + # since the downstream app has consumed us all that is left + # is to send it a disconnect + if self._is_disconnected: + # the middleware has already seen the disconnect + # since we know the client is disconnected no need to wait + # for the message + self._wrapped_rcv_disconnected = True + return {"type": "http.disconnect"} + # we don't know yet if the client is disconnected or not + # so we'll wait until we get that message + msg = await self.receive() + if msg["type"] != "http.disconnect": # pragma: no cover + # at this point a disconnect is all that we should be receiving + # if we get something else, things went wrong somewhere + raise RuntimeError(f"Unexpected message received: {msg['type']}") + self._wrapped_rcv_disconnected = True + return msg + + # wrapped_rcv state 3: not yet consumed + if getattr(self, "_body", None) is not None: + # body() was called, we return it even if the client disconnected + self._wrapped_rcv_consumed = True + return { + "type": "http.request", + "body": self._body, + "more_body": False, + } + elif self._stream_consumed: + # stream() was called to completion + # return an empty body so that downstream apps don't hang + # waiting for a disconnect + self._wrapped_rcv_consumed = True + return { + "type": "http.request", + "body": b"", + "more_body": False, + } + else: + # body() was never called and stream() wasn't consumed + try: + stream = self.stream() + chunk = await stream.__anext__() + self._wrapped_rcv_consumed = self._stream_consumed + return { + "type": "http.request", + "body": chunk, + "more_body": not self._stream_consumed, + } + except ClientDisconnect: + self._wrapped_rcv_disconnected = True + return {"type": "http.disconnect"} + + +class BaseHTTPMiddleware: + def __init__(self, app: ASGIApp, dispatch: DispatchFunction | None = None) -> None: + self.app = app + self.dispatch_func = self.dispatch if dispatch is None else dispatch + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + request = _CachedRequest(scope, receive) + wrapped_receive = request.wrapped_receive + response_sent = anyio.Event() + app_exc: Exception | None = None + + async def call_next(request: Request) -> Response: + async def receive_or_disconnect() -> Message: + if response_sent.is_set(): + return {"type": "http.disconnect"} + + async with anyio.create_task_group() as task_group: + + async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T: + result = await func() + task_group.cancel_scope.cancel() + return result + + task_group.start_soon(wrap, response_sent.wait) + message = await wrap(wrapped_receive) + + if response_sent.is_set(): + return {"type": "http.disconnect"} + + return message + + async def send_no_error(message: Message) -> None: + try: + await send_stream.send(message) + except anyio.BrokenResourceError: + # recv_stream has been closed, i.e. response_sent has been set. + return + + async def coro() -> None: + nonlocal app_exc + + with send_stream: + try: + await self.app(scope, receive_or_disconnect, send_no_error) + except Exception as exc: + app_exc = exc + + task_group.start_soon(coro) + + try: + message = await recv_stream.receive() + info = message.get("info", None) + if message["type"] == "http.response.debug" and info is not None: + message = await recv_stream.receive() + except anyio.EndOfStream: + if app_exc is not None: + raise app_exc + raise RuntimeError("No response returned.") + + assert message["type"] == "http.response.start" + + async def body_stream() -> typing.AsyncGenerator[bytes, None]: + async for message in recv_stream: + assert message["type"] == "http.response.body" + body = message.get("body", b"") + if body: + yield body + if not message.get("more_body", False): + break + + response = _StreamingResponse(status_code=message["status"], content=body_stream(), info=info) + response.raw_headers = message["headers"] + return response + + streams: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream() + send_stream, recv_stream = streams + with recv_stream, send_stream, collapse_excgroups(): + async with anyio.create_task_group() as task_group: + response = await self.dispatch_func(request, call_next) + await response(scope, wrapped_receive, send) + response_sent.set() + recv_stream.close() + + if app_exc is not None: + raise app_exc + + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: + raise NotImplementedError() # pragma: no cover + + +class _StreamingResponse(Response): + def __init__( + self, + content: AsyncContentStream, + status_code: int = 200, + headers: typing.Mapping[str, str] | None = None, + media_type: str | None = None, + info: typing.Mapping[str, typing.Any] | None = None, + ) -> None: + self.info = info + self.body_iterator = content + self.status_code = status_code + self.media_type = media_type + self.init_headers(headers) + self.background = None + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if self.info is not None: + await send({"type": "http.response.debug", "info": self.info}) + await send( + { + "type": "http.response.start", + "status": self.status_code, + "headers": self.raw_headers, + } + ) + + async for chunk in self.body_iterator: + await send({"type": "http.response.body", "body": chunk, "more_body": True}) + + await send({"type": "http.response.body", "body": b"", "more_body": False}) + + if self.background: + await self.background() diff --git a/.venv/lib/python3.12/site-packages/starlette/middleware/cors.py b/.venv/lib/python3.12/site-packages/starlette/middleware/cors.py new file mode 100644 index 00000000..61502691 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/middleware/cors.py @@ -0,0 +1,172 @@ +from __future__ import annotations + +import functools +import re +import typing + +from starlette.datastructures import Headers, MutableHeaders +from starlette.responses import PlainTextResponse, Response +from starlette.types import ASGIApp, Message, Receive, Scope, Send + +ALL_METHODS = ("DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT") +SAFELISTED_HEADERS = {"Accept", "Accept-Language", "Content-Language", "Content-Type"} + + +class CORSMiddleware: + def __init__( + self, + app: ASGIApp, + allow_origins: typing.Sequence[str] = (), + allow_methods: typing.Sequence[str] = ("GET",), + allow_headers: typing.Sequence[str] = (), + allow_credentials: bool = False, + allow_origin_regex: str | None = None, + expose_headers: typing.Sequence[str] = (), + max_age: int = 600, + ) -> None: + if "*" in allow_methods: + allow_methods = ALL_METHODS + + compiled_allow_origin_regex = None + if allow_origin_regex is not None: + compiled_allow_origin_regex = re.compile(allow_origin_regex) + + allow_all_origins = "*" in allow_origins + allow_all_headers = "*" in allow_headers + preflight_explicit_allow_origin = not allow_all_origins or allow_credentials + + simple_headers = {} + if allow_all_origins: + simple_headers["Access-Control-Allow-Origin"] = "*" + if allow_credentials: + simple_headers["Access-Control-Allow-Credentials"] = "true" + if expose_headers: + simple_headers["Access-Control-Expose-Headers"] = ", ".join(expose_headers) + + preflight_headers = {} + if preflight_explicit_allow_origin: + # The origin value will be set in preflight_response() if it is allowed. + preflight_headers["Vary"] = "Origin" + else: + preflight_headers["Access-Control-Allow-Origin"] = "*" + preflight_headers.update( + { + "Access-Control-Allow-Methods": ", ".join(allow_methods), + "Access-Control-Max-Age": str(max_age), + } + ) + allow_headers = sorted(SAFELISTED_HEADERS | set(allow_headers)) + if allow_headers and not allow_all_headers: + preflight_headers["Access-Control-Allow-Headers"] = ", ".join(allow_headers) + if allow_credentials: + preflight_headers["Access-Control-Allow-Credentials"] = "true" + + self.app = app + self.allow_origins = allow_origins + self.allow_methods = allow_methods + self.allow_headers = [h.lower() for h in allow_headers] + self.allow_all_origins = allow_all_origins + self.allow_all_headers = allow_all_headers + self.preflight_explicit_allow_origin = preflight_explicit_allow_origin + self.allow_origin_regex = compiled_allow_origin_regex + self.simple_headers = simple_headers + self.preflight_headers = preflight_headers + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http": # pragma: no cover + await self.app(scope, receive, send) + return + + method = scope["method"] + headers = Headers(scope=scope) + origin = headers.get("origin") + + if origin is None: + await self.app(scope, receive, send) + return + + if method == "OPTIONS" and "access-control-request-method" in headers: + response = self.preflight_response(request_headers=headers) + await response(scope, receive, send) + return + + await self.simple_response(scope, receive, send, request_headers=headers) + + def is_allowed_origin(self, origin: str) -> bool: + if self.allow_all_origins: + return True + + if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch(origin): + return True + + return origin in self.allow_origins + + def preflight_response(self, request_headers: Headers) -> Response: + requested_origin = request_headers["origin"] + requested_method = request_headers["access-control-request-method"] + requested_headers = request_headers.get("access-control-request-headers") + + headers = dict(self.preflight_headers) + failures = [] + + if self.is_allowed_origin(origin=requested_origin): + if self.preflight_explicit_allow_origin: + # The "else" case is already accounted for in self.preflight_headers + # and the value would be "*". + headers["Access-Control-Allow-Origin"] = requested_origin + else: + failures.append("origin") + + if requested_method not in self.allow_methods: + failures.append("method") + + # If we allow all headers, then we have to mirror back any requested + # headers in the response. + if self.allow_all_headers and requested_headers is not None: + headers["Access-Control-Allow-Headers"] = requested_headers + elif requested_headers is not None: + for header in [h.lower() for h in requested_headers.split(",")]: + if header.strip() not in self.allow_headers: + failures.append("headers") + break + + # We don't strictly need to use 400 responses here, since its up to + # the browser to enforce the CORS policy, but its more informative + # if we do. + if failures: + failure_text = "Disallowed CORS " + ", ".join(failures) + return PlainTextResponse(failure_text, status_code=400, headers=headers) + + return PlainTextResponse("OK", status_code=200, headers=headers) + + async def simple_response(self, scope: Scope, receive: Receive, send: Send, request_headers: Headers) -> None: + send = functools.partial(self.send, send=send, request_headers=request_headers) + await self.app(scope, receive, send) + + async def send(self, message: Message, send: Send, request_headers: Headers) -> None: + if message["type"] != "http.response.start": + await send(message) + return + + message.setdefault("headers", []) + headers = MutableHeaders(scope=message) + headers.update(self.simple_headers) + origin = request_headers["Origin"] + has_cookie = "cookie" in request_headers + + # If request includes any cookie headers, then we must respond + # with the specific origin instead of '*'. + if self.allow_all_origins and has_cookie: + self.allow_explicit_origin(headers, origin) + + # If we only allow specific origins, then we have to mirror back + # the Origin header in the response. + elif not self.allow_all_origins and self.is_allowed_origin(origin=origin): + self.allow_explicit_origin(headers, origin) + + await send(message) + + @staticmethod + def allow_explicit_origin(headers: MutableHeaders, origin: str) -> None: + headers["Access-Control-Allow-Origin"] = origin + headers.add_vary_header("Origin") diff --git a/.venv/lib/python3.12/site-packages/starlette/middleware/errors.py b/.venv/lib/python3.12/site-packages/starlette/middleware/errors.py new file mode 100644 index 00000000..76ad776b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/middleware/errors.py @@ -0,0 +1,260 @@ +from __future__ import annotations + +import html +import inspect +import sys +import traceback +import typing + +from starlette._utils import is_async_callable +from starlette.concurrency import run_in_threadpool +from starlette.requests import Request +from starlette.responses import HTMLResponse, PlainTextResponse, Response +from starlette.types import ASGIApp, Message, Receive, Scope, Send + +STYLES = """ +p { + color: #211c1c; +} +.traceback-container { + border: 1px solid #038BB8; +} +.traceback-title { + background-color: #038BB8; + color: lemonchiffon; + padding: 12px; + font-size: 20px; + margin-top: 0px; +} +.frame-line { + padding-left: 10px; + font-family: monospace; +} +.frame-filename { + font-family: monospace; +} +.center-line { + background-color: #038BB8; + color: #f9f6e1; + padding: 5px 0px 5px 5px; +} +.lineno { + margin-right: 5px; +} +.frame-title { + font-weight: unset; + padding: 10px 10px 10px 10px; + background-color: #E4F4FD; + margin-right: 10px; + color: #191f21; + font-size: 17px; + border: 1px solid #c7dce8; +} +.collapse-btn { + float: right; + padding: 0px 5px 1px 5px; + border: solid 1px #96aebb; + cursor: pointer; +} +.collapsed { + display: none; +} +.source-code { + font-family: courier; + font-size: small; + padding-bottom: 10px; +} +""" + +JS = """ +<script type="text/javascript"> + function collapse(element){ + const frameId = element.getAttribute("data-frame-id"); + const frame = document.getElementById(frameId); + + if (frame.classList.contains("collapsed")){ + element.innerHTML = "‒"; + frame.classList.remove("collapsed"); + } else { + element.innerHTML = "+"; + frame.classList.add("collapsed"); + } + } +</script> +""" + +TEMPLATE = """ +<html> + <head> + <style type='text/css'> + {styles} + </style> + <title>Starlette Debugger</title> + </head> + <body> + <h1>500 Server Error</h1> + <h2>{error}</h2> + <div class="traceback-container"> + <p class="traceback-title">Traceback</p> + <div>{exc_html}</div> + </div> + {js} + </body> +</html> +""" + +FRAME_TEMPLATE = """ +<div> + <p class="frame-title">File <span class="frame-filename">{frame_filename}</span>, + line <i>{frame_lineno}</i>, + in <b>{frame_name}</b> + <span class="collapse-btn" data-frame-id="{frame_filename}-{frame_lineno}" onclick="collapse(this)">{collapse_button}</span> + </p> + <div id="{frame_filename}-{frame_lineno}" class="source-code {collapsed}">{code_context}</div> +</div> +""" # noqa: E501 + +LINE = """ +<p><span class="frame-line"> +<span class="lineno">{lineno}.</span> {line}</span></p> +""" + +CENTER_LINE = """ +<p class="center-line"><span class="frame-line center-line"> +<span class="lineno">{lineno}.</span> {line}</span></p> +""" + + +class ServerErrorMiddleware: + """ + Handles returning 500 responses when a server error occurs. + + If 'debug' is set, then traceback responses will be returned, + otherwise the designated 'handler' will be called. + + This middleware class should generally be used to wrap *everything* + else up, so that unhandled exceptions anywhere in the stack + always result in an appropriate 500 response. + """ + + def __init__( + self, + app: ASGIApp, + handler: typing.Callable[[Request, Exception], typing.Any] | None = None, + debug: bool = False, + ) -> None: + self.app = app + self.handler = handler + self.debug = debug + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + response_started = False + + async def _send(message: Message) -> None: + nonlocal response_started, send + + if message["type"] == "http.response.start": + response_started = True + await send(message) + + try: + await self.app(scope, receive, _send) + except Exception as exc: + request = Request(scope) + if self.debug: + # In debug mode, return traceback responses. + response = self.debug_response(request, exc) + elif self.handler is None: + # Use our default 500 error handler. + response = self.error_response(request, exc) + else: + # Use an installed 500 error handler. + if is_async_callable(self.handler): + response = await self.handler(request, exc) + else: + response = await run_in_threadpool(self.handler, request, exc) + + if not response_started: + await response(scope, receive, send) + + # We always continue to raise the exception. + # This allows servers to log the error, or allows test clients + # to optionally raise the error within the test case. + raise exc + + def format_line(self, index: int, line: str, frame_lineno: int, frame_index: int) -> str: + values = { + # HTML escape - line could contain < or > + "line": html.escape(line).replace(" ", " "), + "lineno": (frame_lineno - frame_index) + index, + } + + if index != frame_index: + return LINE.format(**values) + return CENTER_LINE.format(**values) + + def generate_frame_html(self, frame: inspect.FrameInfo, is_collapsed: bool) -> str: + code_context = "".join( + self.format_line( + index, + line, + frame.lineno, + frame.index, # type: ignore[arg-type] + ) + for index, line in enumerate(frame.code_context or []) + ) + + values = { + # HTML escape - filename could contain < or >, especially if it's a virtual + # file e.g. <stdin> in the REPL + "frame_filename": html.escape(frame.filename), + "frame_lineno": frame.lineno, + # HTML escape - if you try very hard it's possible to name a function with < + # or > + "frame_name": html.escape(frame.function), + "code_context": code_context, + "collapsed": "collapsed" if is_collapsed else "", + "collapse_button": "+" if is_collapsed else "‒", + } + return FRAME_TEMPLATE.format(**values) + + def generate_html(self, exc: Exception, limit: int = 7) -> str: + traceback_obj = traceback.TracebackException.from_exception(exc, capture_locals=True) + + exc_html = "" + is_collapsed = False + exc_traceback = exc.__traceback__ + if exc_traceback is not None: + frames = inspect.getinnerframes(exc_traceback, limit) + for frame in reversed(frames): + exc_html += self.generate_frame_html(frame, is_collapsed) + is_collapsed = True + + if sys.version_info >= (3, 13): # pragma: no cover + exc_type_str = traceback_obj.exc_type_str + else: # pragma: no cover + exc_type_str = traceback_obj.exc_type.__name__ + + # escape error class and text + error = f"{html.escape(exc_type_str)}: {html.escape(str(traceback_obj))}" + + return TEMPLATE.format(styles=STYLES, js=JS, error=error, exc_html=exc_html) + + def generate_plain_text(self, exc: Exception) -> str: + return "".join(traceback.format_exception(type(exc), exc, exc.__traceback__)) + + def debug_response(self, request: Request, exc: Exception) -> Response: + accept = request.headers.get("accept", "") + + if "text/html" in accept: + content = self.generate_html(exc) + return HTMLResponse(content, status_code=500) + content = self.generate_plain_text(exc) + return PlainTextResponse(content, status_code=500) + + def error_response(self, request: Request, exc: Exception) -> Response: + return PlainTextResponse("Internal Server Error", status_code=500) diff --git a/.venv/lib/python3.12/site-packages/starlette/middleware/exceptions.py b/.venv/lib/python3.12/site-packages/starlette/middleware/exceptions.py new file mode 100644 index 00000000..981d2fca --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/middleware/exceptions.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +import typing + +from starlette._exception_handler import ( + ExceptionHandlers, + StatusHandlers, + wrap_app_handling_exceptions, +) +from starlette.exceptions import HTTPException, WebSocketException +from starlette.requests import Request +from starlette.responses import PlainTextResponse, Response +from starlette.types import ASGIApp, Receive, Scope, Send +from starlette.websockets import WebSocket + + +class ExceptionMiddleware: + def __init__( + self, + app: ASGIApp, + handlers: typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]] | None = None, + debug: bool = False, + ) -> None: + self.app = app + self.debug = debug # TODO: We ought to handle 404 cases if debug is set. + self._status_handlers: StatusHandlers = {} + self._exception_handlers: ExceptionHandlers = { + HTTPException: self.http_exception, + WebSocketException: self.websocket_exception, + } + if handlers is not None: # pragma: no branch + for key, value in handlers.items(): + self.add_exception_handler(key, value) + + def add_exception_handler( + self, + exc_class_or_status_code: int | type[Exception], + handler: typing.Callable[[Request, Exception], Response], + ) -> None: + if isinstance(exc_class_or_status_code, int): + self._status_handlers[exc_class_or_status_code] = handler + else: + assert issubclass(exc_class_or_status_code, Exception) + self._exception_handlers[exc_class_or_status_code] = handler + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] not in ("http", "websocket"): + await self.app(scope, receive, send) + return + + scope["starlette.exception_handlers"] = ( + self._exception_handlers, + self._status_handlers, + ) + + conn: Request | WebSocket + if scope["type"] == "http": + conn = Request(scope, receive, send) + else: + conn = WebSocket(scope, receive, send) + + await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send) + + def http_exception(self, request: Request, exc: Exception) -> Response: + assert isinstance(exc, HTTPException) + if exc.status_code in {204, 304}: + return Response(status_code=exc.status_code, headers=exc.headers) + return PlainTextResponse(exc.detail, status_code=exc.status_code, headers=exc.headers) + + async def websocket_exception(self, websocket: WebSocket, exc: Exception) -> None: + assert isinstance(exc, WebSocketException) + await websocket.close(code=exc.code, reason=exc.reason) # pragma: no cover diff --git a/.venv/lib/python3.12/site-packages/starlette/middleware/gzip.py b/.venv/lib/python3.12/site-packages/starlette/middleware/gzip.py new file mode 100644 index 00000000..c7fd5b77 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/middleware/gzip.py @@ -0,0 +1,141 @@ +import gzip +import io +import typing + +from starlette.datastructures import Headers, MutableHeaders +from starlette.types import ASGIApp, Message, Receive, Scope, Send + +DEFAULT_EXCLUDED_CONTENT_TYPES = ("text/event-stream",) + + +class GZipMiddleware: + def __init__(self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9) -> None: + self.app = app + self.minimum_size = minimum_size + self.compresslevel = compresslevel + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http": # pragma: no cover + await self.app(scope, receive, send) + return + + headers = Headers(scope=scope) + responder: ASGIApp + if "gzip" in headers.get("Accept-Encoding", ""): + responder = GZipResponder(self.app, self.minimum_size, compresslevel=self.compresslevel) + else: + responder = IdentityResponder(self.app, self.minimum_size) + + await responder(scope, receive, send) + + +class IdentityResponder: + content_encoding: str + + def __init__(self, app: ASGIApp, minimum_size: int) -> None: + self.app = app + self.minimum_size = minimum_size + self.send: Send = unattached_send + self.initial_message: Message = {} + self.started = False + self.content_encoding_set = False + self.content_type_is_excluded = False + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + self.send = send + await self.app(scope, receive, self.send_with_compression) + + async def send_with_compression(self, message: Message) -> None: + message_type = message["type"] + if message_type == "http.response.start": + # Don't send the initial message until we've determined how to + # modify the outgoing headers correctly. + self.initial_message = message + headers = Headers(raw=self.initial_message["headers"]) + self.content_encoding_set = "content-encoding" in headers + self.content_type_is_excluded = headers.get("content-type", "").startswith(DEFAULT_EXCLUDED_CONTENT_TYPES) + elif message_type == "http.response.body" and (self.content_encoding_set or self.content_type_is_excluded): + if not self.started: + self.started = True + await self.send(self.initial_message) + await self.send(message) + elif message_type == "http.response.body" and not self.started: + self.started = True + body = message.get("body", b"") + more_body = message.get("more_body", False) + if len(body) < self.minimum_size and not more_body: + # Don't apply compression to small outgoing responses. + await self.send(self.initial_message) + await self.send(message) + elif not more_body: + # Standard response. + body = self.apply_compression(body, more_body=False) + + headers = MutableHeaders(raw=self.initial_message["headers"]) + headers.add_vary_header("Accept-Encoding") + if body != message["body"]: + headers["Content-Encoding"] = self.content_encoding + headers["Content-Length"] = str(len(body)) + message["body"] = body + + await self.send(self.initial_message) + await self.send(message) + else: + # Initial body in streaming response. + body = self.apply_compression(body, more_body=True) + + headers = MutableHeaders(raw=self.initial_message["headers"]) + headers.add_vary_header("Accept-Encoding") + if body != message["body"]: + headers["Content-Encoding"] = self.content_encoding + del headers["Content-Length"] + message["body"] = body + + await self.send(self.initial_message) + await self.send(message) + elif message_type == "http.response.body": # pragma: no branch + # Remaining body in streaming response. + body = message.get("body", b"") + more_body = message.get("more_body", False) + + message["body"] = self.apply_compression(body, more_body=more_body) + + await self.send(message) + + def apply_compression(self, body: bytes, *, more_body: bool) -> bytes: + """Apply compression on the response body. + + If more_body is False, any compression file should be closed. If it + isn't, it won't be closed automatically until all background tasks + complete. + """ + return body + + +class GZipResponder(IdentityResponder): + content_encoding = "gzip" + + def __init__(self, app: ASGIApp, minimum_size: int, compresslevel: int = 9) -> None: + super().__init__(app, minimum_size) + + self.gzip_buffer = io.BytesIO() + self.gzip_file = gzip.GzipFile(mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + with self.gzip_buffer, self.gzip_file: + await super().__call__(scope, receive, send) + + def apply_compression(self, body: bytes, *, more_body: bool) -> bytes: + self.gzip_file.write(body) + if not more_body: + self.gzip_file.close() + + body = self.gzip_buffer.getvalue() + self.gzip_buffer.seek(0) + self.gzip_buffer.truncate() + + return body + + +async def unattached_send(message: Message) -> typing.NoReturn: + raise RuntimeError("send awaitable not set") # pragma: no cover diff --git a/.venv/lib/python3.12/site-packages/starlette/middleware/httpsredirect.py b/.venv/lib/python3.12/site-packages/starlette/middleware/httpsredirect.py new file mode 100644 index 00000000..a8359067 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/middleware/httpsredirect.py @@ -0,0 +1,19 @@ +from starlette.datastructures import URL +from starlette.responses import RedirectResponse +from starlette.types import ASGIApp, Receive, Scope, Send + + +class HTTPSRedirectMiddleware: + def __init__(self, app: ASGIApp) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] in ("http", "websocket") and scope["scheme"] in ("http", "ws"): + url = URL(scope=scope) + redirect_scheme = {"http": "https", "ws": "wss"}[url.scheme] + netloc = url.hostname if url.port in (80, 443) else url.netloc + url = url.replace(scheme=redirect_scheme, netloc=netloc) + response = RedirectResponse(url, status_code=307) + await response(scope, receive, send) + else: + await self.app(scope, receive, send) diff --git a/.venv/lib/python3.12/site-packages/starlette/middleware/sessions.py b/.venv/lib/python3.12/site-packages/starlette/middleware/sessions.py new file mode 100644 index 00000000..5f9fcd88 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/middleware/sessions.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +import json +import typing +from base64 import b64decode, b64encode + +import itsdangerous +from itsdangerous.exc import BadSignature + +from starlette.datastructures import MutableHeaders, Secret +from starlette.requests import HTTPConnection +from starlette.types import ASGIApp, Message, Receive, Scope, Send + + +class SessionMiddleware: + def __init__( + self, + app: ASGIApp, + secret_key: str | Secret, + session_cookie: str = "session", + max_age: int | None = 14 * 24 * 60 * 60, # 14 days, in seconds + path: str = "/", + same_site: typing.Literal["lax", "strict", "none"] = "lax", + https_only: bool = False, + domain: str | None = None, + ) -> None: + self.app = app + self.signer = itsdangerous.TimestampSigner(str(secret_key)) + self.session_cookie = session_cookie + self.max_age = max_age + self.path = path + self.security_flags = "httponly; samesite=" + same_site + if https_only: # Secure flag can be used with HTTPS only + self.security_flags += "; secure" + if domain is not None: + self.security_flags += f"; domain={domain}" + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] not in ("http", "websocket"): # pragma: no cover + await self.app(scope, receive, send) + return + + connection = HTTPConnection(scope) + initial_session_was_empty = True + + if self.session_cookie in connection.cookies: + data = connection.cookies[self.session_cookie].encode("utf-8") + try: + data = self.signer.unsign(data, max_age=self.max_age) + scope["session"] = json.loads(b64decode(data)) + initial_session_was_empty = False + except BadSignature: + scope["session"] = {} + else: + scope["session"] = {} + + async def send_wrapper(message: Message) -> None: + if message["type"] == "http.response.start": + if scope["session"]: + # We have session data to persist. + data = b64encode(json.dumps(scope["session"]).encode("utf-8")) + data = self.signer.sign(data) + headers = MutableHeaders(scope=message) + header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format( + session_cookie=self.session_cookie, + data=data.decode("utf-8"), + path=self.path, + max_age=f"Max-Age={self.max_age}; " if self.max_age else "", + security_flags=self.security_flags, + ) + headers.append("Set-Cookie", header_value) + elif not initial_session_was_empty: + # The session has been cleared. + headers = MutableHeaders(scope=message) + header_value = "{session_cookie}={data}; path={path}; {expires}{security_flags}".format( + session_cookie=self.session_cookie, + data="null", + path=self.path, + expires="expires=Thu, 01 Jan 1970 00:00:00 GMT; ", + security_flags=self.security_flags, + ) + headers.append("Set-Cookie", header_value) + await send(message) + + await self.app(scope, receive, send_wrapper) diff --git a/.venv/lib/python3.12/site-packages/starlette/middleware/trustedhost.py b/.venv/lib/python3.12/site-packages/starlette/middleware/trustedhost.py new file mode 100644 index 00000000..2d1c999e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/middleware/trustedhost.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import typing + +from starlette.datastructures import URL, Headers +from starlette.responses import PlainTextResponse, RedirectResponse, Response +from starlette.types import ASGIApp, Receive, Scope, Send + +ENFORCE_DOMAIN_WILDCARD = "Domain wildcard patterns must be like '*.example.com'." + + +class TrustedHostMiddleware: + def __init__( + self, + app: ASGIApp, + allowed_hosts: typing.Sequence[str] | None = None, + www_redirect: bool = True, + ) -> None: + if allowed_hosts is None: + allowed_hosts = ["*"] + + for pattern in allowed_hosts: + assert "*" not in pattern[1:], ENFORCE_DOMAIN_WILDCARD + if pattern.startswith("*") and pattern != "*": + assert pattern.startswith("*."), ENFORCE_DOMAIN_WILDCARD + self.app = app + self.allowed_hosts = list(allowed_hosts) + self.allow_any = "*" in allowed_hosts + self.www_redirect = www_redirect + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if self.allow_any or scope["type"] not in ( + "http", + "websocket", + ): # pragma: no cover + await self.app(scope, receive, send) + return + + headers = Headers(scope=scope) + host = headers.get("host", "").split(":")[0] + is_valid_host = False + found_www_redirect = False + for pattern in self.allowed_hosts: + if host == pattern or (pattern.startswith("*") and host.endswith(pattern[1:])): + is_valid_host = True + break + elif "www." + host == pattern: + found_www_redirect = True + + if is_valid_host: + await self.app(scope, receive, send) + else: + response: Response + if found_www_redirect and self.www_redirect: + url = URL(scope=scope) + redirect_url = url.replace(netloc="www." + url.netloc) + response = RedirectResponse(url=str(redirect_url)) + else: + response = PlainTextResponse("Invalid host header", status_code=400) + await response(scope, receive, send) diff --git a/.venv/lib/python3.12/site-packages/starlette/middleware/wsgi.py b/.venv/lib/python3.12/site-packages/starlette/middleware/wsgi.py new file mode 100644 index 00000000..6e0a3fae --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/middleware/wsgi.py @@ -0,0 +1,152 @@ +from __future__ import annotations + +import io +import math +import sys +import typing +import warnings + +import anyio +from anyio.abc import ObjectReceiveStream, ObjectSendStream + +from starlette.types import Receive, Scope, Send + +warnings.warn( + "starlette.middleware.wsgi is deprecated and will be removed in a future release. " + "Please refer to https://github.com/abersheeran/a2wsgi as a replacement.", + DeprecationWarning, +) + + +def build_environ(scope: Scope, body: bytes) -> dict[str, typing.Any]: + """ + Builds a scope and request body into a WSGI environ object. + """ + + script_name = scope.get("root_path", "").encode("utf8").decode("latin1") + path_info = scope["path"].encode("utf8").decode("latin1") + if path_info.startswith(script_name): + path_info = path_info[len(script_name) :] + + environ = { + "REQUEST_METHOD": scope["method"], + "SCRIPT_NAME": script_name, + "PATH_INFO": path_info, + "QUERY_STRING": scope["query_string"].decode("ascii"), + "SERVER_PROTOCOL": f"HTTP/{scope['http_version']}", + "wsgi.version": (1, 0), + "wsgi.url_scheme": scope.get("scheme", "http"), + "wsgi.input": io.BytesIO(body), + "wsgi.errors": sys.stdout, + "wsgi.multithread": True, + "wsgi.multiprocess": True, + "wsgi.run_once": False, + } + + # Get server name and port - required in WSGI, not in ASGI + server = scope.get("server") or ("localhost", 80) + environ["SERVER_NAME"] = server[0] + environ["SERVER_PORT"] = server[1] + + # Get client IP address + if scope.get("client"): + environ["REMOTE_ADDR"] = scope["client"][0] + + # Go through headers and make them into environ entries + for name, value in scope.get("headers", []): + name = name.decode("latin1") + if name == "content-length": + corrected_name = "CONTENT_LENGTH" + elif name == "content-type": + corrected_name = "CONTENT_TYPE" + else: + corrected_name = f"HTTP_{name}".upper().replace("-", "_") + # HTTPbis say only ASCII chars are allowed in headers, but we latin1 just in + # case + value = value.decode("latin1") + if corrected_name in environ: + value = environ[corrected_name] + "," + value + environ[corrected_name] = value + return environ + + +class WSGIMiddleware: + def __init__(self, app: typing.Callable[..., typing.Any]) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + assert scope["type"] == "http" + responder = WSGIResponder(self.app, scope) + await responder(receive, send) + + +class WSGIResponder: + stream_send: ObjectSendStream[typing.MutableMapping[str, typing.Any]] + stream_receive: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]] + + def __init__(self, app: typing.Callable[..., typing.Any], scope: Scope) -> None: + self.app = app + self.scope = scope + self.status = None + self.response_headers = None + self.stream_send, self.stream_receive = anyio.create_memory_object_stream(math.inf) + self.response_started = False + self.exc_info: typing.Any = None + + async def __call__(self, receive: Receive, send: Send) -> None: + body = b"" + more_body = True + while more_body: + message = await receive() + body += message.get("body", b"") + more_body = message.get("more_body", False) + environ = build_environ(self.scope, body) + + async with anyio.create_task_group() as task_group: + task_group.start_soon(self.sender, send) + async with self.stream_send: + await anyio.to_thread.run_sync(self.wsgi, environ, self.start_response) + if self.exc_info is not None: + raise self.exc_info[0].with_traceback(self.exc_info[1], self.exc_info[2]) + + async def sender(self, send: Send) -> None: + async with self.stream_receive: + async for message in self.stream_receive: + await send(message) + + def start_response( + self, + status: str, + response_headers: list[tuple[str, str]], + exc_info: typing.Any = None, + ) -> None: + self.exc_info = exc_info + if not self.response_started: # pragma: no branch + self.response_started = True + status_code_string, _ = status.split(" ", 1) + status_code = int(status_code_string) + headers = [ + (name.strip().encode("ascii").lower(), value.strip().encode("ascii")) + for name, value in response_headers + ] + anyio.from_thread.run( + self.stream_send.send, + { + "type": "http.response.start", + "status": status_code, + "headers": headers, + }, + ) + + def wsgi( + self, + environ: dict[str, typing.Any], + start_response: typing.Callable[..., typing.Any], + ) -> None: + for chunk in self.app(environ, start_response): + anyio.from_thread.run( + self.stream_send.send, + {"type": "http.response.body", "body": chunk, "more_body": True}, + ) + + anyio.from_thread.run(self.stream_send.send, {"type": "http.response.body", "body": b""}) diff --git a/.venv/lib/python3.12/site-packages/starlette/py.typed b/.venv/lib/python3.12/site-packages/starlette/py.typed new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/py.typed diff --git a/.venv/lib/python3.12/site-packages/starlette/requests.py b/.venv/lib/python3.12/site-packages/starlette/requests.py new file mode 100644 index 00000000..7dc04a74 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/requests.py @@ -0,0 +1,322 @@ +from __future__ import annotations + +import json +import typing +from http import cookies as http_cookies + +import anyio + +from starlette._utils import AwaitableOrContextManager, AwaitableOrContextManagerWrapper +from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State +from starlette.exceptions import HTTPException +from starlette.formparsers import FormParser, MultiPartException, MultiPartParser +from starlette.types import Message, Receive, Scope, Send + +if typing.TYPE_CHECKING: + from python_multipart.multipart import parse_options_header + + from starlette.applications import Starlette + from starlette.routing import Router +else: + try: + try: + from python_multipart.multipart import parse_options_header + except ModuleNotFoundError: # pragma: no cover + from multipart.multipart import parse_options_header + except ModuleNotFoundError: # pragma: no cover + parse_options_header = None + + +SERVER_PUSH_HEADERS_TO_COPY = { + "accept", + "accept-encoding", + "accept-language", + "cache-control", + "user-agent", +} + + +def cookie_parser(cookie_string: str) -> dict[str, str]: + """ + This function parses a ``Cookie`` HTTP header into a dict of key/value pairs. + + It attempts to mimic browser cookie parsing behavior: browsers and web servers + frequently disregard the spec (RFC 6265) when setting and reading cookies, + so we attempt to suit the common scenarios here. + + This function has been adapted from Django 3.1.0. + Note: we are explicitly _NOT_ using `SimpleCookie.load` because it is based + on an outdated spec and will fail on lots of input we want to support + """ + cookie_dict: dict[str, str] = {} + for chunk in cookie_string.split(";"): + if "=" in chunk: + key, val = chunk.split("=", 1) + else: + # Assume an empty name per + # https://bugzilla.mozilla.org/show_bug.cgi?id=169091 + key, val = "", chunk + key, val = key.strip(), val.strip() + if key or val: + # unquote using Python's algorithm. + cookie_dict[key] = http_cookies._unquote(val) + return cookie_dict + + +class ClientDisconnect(Exception): + pass + + +class HTTPConnection(typing.Mapping[str, typing.Any]): + """ + A base class for incoming HTTP connections, that is used to provide + any functionality that is common to both `Request` and `WebSocket`. + """ + + def __init__(self, scope: Scope, receive: Receive | None = None) -> None: + assert scope["type"] in ("http", "websocket") + self.scope = scope + + def __getitem__(self, key: str) -> typing.Any: + return self.scope[key] + + def __iter__(self) -> typing.Iterator[str]: + return iter(self.scope) + + def __len__(self) -> int: + return len(self.scope) + + # Don't use the `abc.Mapping.__eq__` implementation. + # Connection instances should never be considered equal + # unless `self is other`. + __eq__ = object.__eq__ + __hash__ = object.__hash__ + + @property + def app(self) -> typing.Any: + return self.scope["app"] + + @property + def url(self) -> URL: + if not hasattr(self, "_url"): # pragma: no branch + self._url = URL(scope=self.scope) + return self._url + + @property + def base_url(self) -> URL: + if not hasattr(self, "_base_url"): + base_url_scope = dict(self.scope) + # This is used by request.url_for, it might be used inside a Mount which + # would have its own child scope with its own root_path, but the base URL + # for url_for should still be the top level app root path. + app_root_path = base_url_scope.get("app_root_path", base_url_scope.get("root_path", "")) + path = app_root_path + if not path.endswith("/"): + path += "/" + base_url_scope["path"] = path + base_url_scope["query_string"] = b"" + base_url_scope["root_path"] = app_root_path + self._base_url = URL(scope=base_url_scope) + return self._base_url + + @property + def headers(self) -> Headers: + if not hasattr(self, "_headers"): + self._headers = Headers(scope=self.scope) + return self._headers + + @property + def query_params(self) -> QueryParams: + if not hasattr(self, "_query_params"): # pragma: no branch + self._query_params = QueryParams(self.scope["query_string"]) + return self._query_params + + @property + def path_params(self) -> dict[str, typing.Any]: + return self.scope.get("path_params", {}) + + @property + def cookies(self) -> dict[str, str]: + if not hasattr(self, "_cookies"): + cookies: dict[str, str] = {} + cookie_header = self.headers.get("cookie") + + if cookie_header: + cookies = cookie_parser(cookie_header) + self._cookies = cookies + return self._cookies + + @property + def client(self) -> Address | None: + # client is a 2 item tuple of (host, port), None if missing + host_port = self.scope.get("client") + if host_port is not None: + return Address(*host_port) + return None + + @property + def session(self) -> dict[str, typing.Any]: + assert "session" in self.scope, "SessionMiddleware must be installed to access request.session" + return self.scope["session"] # type: ignore[no-any-return] + + @property + def auth(self) -> typing.Any: + assert "auth" in self.scope, "AuthenticationMiddleware must be installed to access request.auth" + return self.scope["auth"] + + @property + def user(self) -> typing.Any: + assert "user" in self.scope, "AuthenticationMiddleware must be installed to access request.user" + return self.scope["user"] + + @property + def state(self) -> State: + if not hasattr(self, "_state"): + # Ensure 'state' has an empty dict if it's not already populated. + self.scope.setdefault("state", {}) + # Create a state instance with a reference to the dict in which it should + # store info + self._state = State(self.scope["state"]) + return self._state + + def url_for(self, name: str, /, **path_params: typing.Any) -> URL: + url_path_provider: Router | Starlette | None = self.scope.get("router") or self.scope.get("app") + if url_path_provider is None: + raise RuntimeError("The `url_for` method can only be used inside a Starlette application or with a router.") + url_path = url_path_provider.url_path_for(name, **path_params) + return url_path.make_absolute_url(base_url=self.base_url) + + +async def empty_receive() -> typing.NoReturn: + raise RuntimeError("Receive channel has not been made available") + + +async def empty_send(message: Message) -> typing.NoReturn: + raise RuntimeError("Send channel has not been made available") + + +class Request(HTTPConnection): + _form: FormData | None + + def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send): + super().__init__(scope) + assert scope["type"] == "http" + self._receive = receive + self._send = send + self._stream_consumed = False + self._is_disconnected = False + self._form = None + + @property + def method(self) -> str: + return typing.cast(str, self.scope["method"]) + + @property + def receive(self) -> Receive: + return self._receive + + async def stream(self) -> typing.AsyncGenerator[bytes, None]: + if hasattr(self, "_body"): + yield self._body + yield b"" + return + if self._stream_consumed: + raise RuntimeError("Stream consumed") + while not self._stream_consumed: + message = await self._receive() + if message["type"] == "http.request": + body = message.get("body", b"") + if not message.get("more_body", False): + self._stream_consumed = True + if body: + yield body + elif message["type"] == "http.disconnect": # pragma: no branch + self._is_disconnected = True + raise ClientDisconnect() + yield b"" + + async def body(self) -> bytes: + if not hasattr(self, "_body"): + chunks: list[bytes] = [] + async for chunk in self.stream(): + chunks.append(chunk) + self._body = b"".join(chunks) + return self._body + + async def json(self) -> typing.Any: + if not hasattr(self, "_json"): # pragma: no branch + body = await self.body() + self._json = json.loads(body) + return self._json + + async def _get_form( + self, + *, + max_files: int | float = 1000, + max_fields: int | float = 1000, + max_part_size: int = 1024 * 1024, + ) -> FormData: + if self._form is None: # pragma: no branch + assert parse_options_header is not None, ( + "The `python-multipart` library must be installed to use form parsing." + ) + content_type_header = self.headers.get("Content-Type") + content_type: bytes + content_type, _ = parse_options_header(content_type_header) + if content_type == b"multipart/form-data": + try: + multipart_parser = MultiPartParser( + self.headers, + self.stream(), + max_files=max_files, + max_fields=max_fields, + max_part_size=max_part_size, + ) + self._form = await multipart_parser.parse() + except MultiPartException as exc: + if "app" in self.scope: + raise HTTPException(status_code=400, detail=exc.message) + raise exc + elif content_type == b"application/x-www-form-urlencoded": + form_parser = FormParser(self.headers, self.stream()) + self._form = await form_parser.parse() + else: + self._form = FormData() + return self._form + + def form( + self, + *, + max_files: int | float = 1000, + max_fields: int | float = 1000, + max_part_size: int = 1024 * 1024, + ) -> AwaitableOrContextManager[FormData]: + return AwaitableOrContextManagerWrapper( + self._get_form(max_files=max_files, max_fields=max_fields, max_part_size=max_part_size) + ) + + async def close(self) -> None: + if self._form is not None: # pragma: no branch + await self._form.close() + + async def is_disconnected(self) -> bool: + if not self._is_disconnected: + message: Message = {} + + # If message isn't immediately available, move on + with anyio.CancelScope() as cs: + cs.cancel() + message = await self._receive() + + if message.get("type") == "http.disconnect": + self._is_disconnected = True + + return self._is_disconnected + + async def send_push_promise(self, path: str) -> None: + if "http.response.push" in self.scope.get("extensions", {}): + raw_headers: list[tuple[bytes, bytes]] = [] + for name in SERVER_PUSH_HEADERS_TO_COPY: + for value in self.headers.getlist(name): + raw_headers.append((name.encode("latin-1"), value.encode("latin-1"))) + await self._send({"type": "http.response.push", "path": path, "headers": raw_headers}) diff --git a/.venv/lib/python3.12/site-packages/starlette/responses.py b/.venv/lib/python3.12/site-packages/starlette/responses.py new file mode 100644 index 00000000..81e89fae --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/responses.py @@ -0,0 +1,536 @@ +from __future__ import annotations + +import hashlib +import http.cookies +import json +import os +import re +import stat +import typing +import warnings +from datetime import datetime +from email.utils import format_datetime, formatdate +from functools import partial +from mimetypes import guess_type +from secrets import token_hex +from urllib.parse import quote + +import anyio +import anyio.to_thread + +from starlette._utils import collapse_excgroups +from starlette.background import BackgroundTask +from starlette.concurrency import iterate_in_threadpool +from starlette.datastructures import URL, Headers, MutableHeaders +from starlette.requests import ClientDisconnect +from starlette.types import Receive, Scope, Send + + +class Response: + media_type = None + charset = "utf-8" + + def __init__( + self, + content: typing.Any = None, + status_code: int = 200, + headers: typing.Mapping[str, str] | None = None, + media_type: str | None = None, + background: BackgroundTask | None = None, + ) -> None: + self.status_code = status_code + if media_type is not None: + self.media_type = media_type + self.background = background + self.body = self.render(content) + self.init_headers(headers) + + def render(self, content: typing.Any) -> bytes | memoryview: + if content is None: + return b"" + if isinstance(content, (bytes, memoryview)): + return content + return content.encode(self.charset) # type: ignore + + def init_headers(self, headers: typing.Mapping[str, str] | None = None) -> None: + if headers is None: + raw_headers: list[tuple[bytes, bytes]] = [] + populate_content_length = True + populate_content_type = True + else: + raw_headers = [(k.lower().encode("latin-1"), v.encode("latin-1")) for k, v in headers.items()] + keys = [h[0] for h in raw_headers] + populate_content_length = b"content-length" not in keys + populate_content_type = b"content-type" not in keys + + body = getattr(self, "body", None) + if ( + body is not None + and populate_content_length + and not (self.status_code < 200 or self.status_code in (204, 304)) + ): + content_length = str(len(body)) + raw_headers.append((b"content-length", content_length.encode("latin-1"))) + + content_type = self.media_type + if content_type is not None and populate_content_type: + if content_type.startswith("text/") and "charset=" not in content_type.lower(): + content_type += "; charset=" + self.charset + raw_headers.append((b"content-type", content_type.encode("latin-1"))) + + self.raw_headers = raw_headers + + @property + def headers(self) -> MutableHeaders: + if not hasattr(self, "_headers"): + self._headers = MutableHeaders(raw=self.raw_headers) + return self._headers + + def set_cookie( + self, + key: str, + value: str = "", + max_age: int | None = None, + expires: datetime | str | int | None = None, + path: str | None = "/", + domain: str | None = None, + secure: bool = False, + httponly: bool = False, + samesite: typing.Literal["lax", "strict", "none"] | None = "lax", + ) -> None: + cookie: http.cookies.BaseCookie[str] = http.cookies.SimpleCookie() + cookie[key] = value + if max_age is not None: + cookie[key]["max-age"] = max_age + if expires is not None: + if isinstance(expires, datetime): + cookie[key]["expires"] = format_datetime(expires, usegmt=True) + else: + cookie[key]["expires"] = expires + if path is not None: + cookie[key]["path"] = path + if domain is not None: + cookie[key]["domain"] = domain + if secure: + cookie[key]["secure"] = True + if httponly: + cookie[key]["httponly"] = True + if samesite is not None: + assert samesite.lower() in [ + "strict", + "lax", + "none", + ], "samesite must be either 'strict', 'lax' or 'none'" + cookie[key]["samesite"] = samesite + cookie_val = cookie.output(header="").strip() + self.raw_headers.append((b"set-cookie", cookie_val.encode("latin-1"))) + + def delete_cookie( + self, + key: str, + path: str = "/", + domain: str | None = None, + secure: bool = False, + httponly: bool = False, + samesite: typing.Literal["lax", "strict", "none"] | None = "lax", + ) -> None: + self.set_cookie( + key, + max_age=0, + expires=0, + path=path, + domain=domain, + secure=secure, + httponly=httponly, + samesite=samesite, + ) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + prefix = "websocket." if scope["type"] == "websocket" else "" + await send( + { + "type": prefix + "http.response.start", + "status": self.status_code, + "headers": self.raw_headers, + } + ) + await send({"type": prefix + "http.response.body", "body": self.body}) + + if self.background is not None: + await self.background() + + +class HTMLResponse(Response): + media_type = "text/html" + + +class PlainTextResponse(Response): + media_type = "text/plain" + + +class JSONResponse(Response): + media_type = "application/json" + + def __init__( + self, + content: typing.Any, + status_code: int = 200, + headers: typing.Mapping[str, str] | None = None, + media_type: str | None = None, + background: BackgroundTask | None = None, + ) -> None: + super().__init__(content, status_code, headers, media_type, background) + + def render(self, content: typing.Any) -> bytes: + return json.dumps( + content, + ensure_ascii=False, + allow_nan=False, + indent=None, + separators=(",", ":"), + ).encode("utf-8") + + +class RedirectResponse(Response): + def __init__( + self, + url: str | URL, + status_code: int = 307, + headers: typing.Mapping[str, str] | None = None, + background: BackgroundTask | None = None, + ) -> None: + super().__init__(content=b"", status_code=status_code, headers=headers, background=background) + self.headers["location"] = quote(str(url), safe=":/%#?=@[]!$&'()*+,;") + + +Content = typing.Union[str, bytes, memoryview] +SyncContentStream = typing.Iterable[Content] +AsyncContentStream = typing.AsyncIterable[Content] +ContentStream = typing.Union[AsyncContentStream, SyncContentStream] + + +class StreamingResponse(Response): + body_iterator: AsyncContentStream + + def __init__( + self, + content: ContentStream, + status_code: int = 200, + headers: typing.Mapping[str, str] | None = None, + media_type: str | None = None, + background: BackgroundTask | None = None, + ) -> None: + if isinstance(content, typing.AsyncIterable): + self.body_iterator = content + else: + self.body_iterator = iterate_in_threadpool(content) + self.status_code = status_code + self.media_type = self.media_type if media_type is None else media_type + self.background = background + self.init_headers(headers) + + async def listen_for_disconnect(self, receive: Receive) -> None: + while True: + message = await receive() + if message["type"] == "http.disconnect": + break + + async def stream_response(self, send: Send) -> None: + await send( + { + "type": "http.response.start", + "status": self.status_code, + "headers": self.raw_headers, + } + ) + async for chunk in self.body_iterator: + if not isinstance(chunk, (bytes, memoryview)): + chunk = chunk.encode(self.charset) + await send({"type": "http.response.body", "body": chunk, "more_body": True}) + + await send({"type": "http.response.body", "body": b"", "more_body": False}) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + spec_version = tuple(map(int, scope.get("asgi", {}).get("spec_version", "2.0").split("."))) + + if spec_version >= (2, 4): + try: + await self.stream_response(send) + except OSError: + raise ClientDisconnect() + else: + with collapse_excgroups(): + async with anyio.create_task_group() as task_group: + + async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None: + await func() + task_group.cancel_scope.cancel() + + task_group.start_soon(wrap, partial(self.stream_response, send)) + await wrap(partial(self.listen_for_disconnect, receive)) + + if self.background is not None: + await self.background() + + +class MalformedRangeHeader(Exception): + def __init__(self, content: str = "Malformed range header.") -> None: + self.content = content + + +class RangeNotSatisfiable(Exception): + def __init__(self, max_size: int) -> None: + self.max_size = max_size + + +_RANGE_PATTERN = re.compile(r"(\d*)-(\d*)") + + +class FileResponse(Response): + chunk_size = 64 * 1024 + + def __init__( + self, + path: str | os.PathLike[str], + status_code: int = 200, + headers: typing.Mapping[str, str] | None = None, + media_type: str | None = None, + background: BackgroundTask | None = None, + filename: str | None = None, + stat_result: os.stat_result | None = None, + method: str | None = None, + content_disposition_type: str = "attachment", + ) -> None: + self.path = path + self.status_code = status_code + self.filename = filename + if method is not None: + warnings.warn( + "The 'method' parameter is not used, and it will be removed.", + DeprecationWarning, + ) + if media_type is None: + media_type = guess_type(filename or path)[0] or "text/plain" + self.media_type = media_type + self.background = background + self.init_headers(headers) + self.headers.setdefault("accept-ranges", "bytes") + if self.filename is not None: + content_disposition_filename = quote(self.filename) + if content_disposition_filename != self.filename: + content_disposition = f"{content_disposition_type}; filename*=utf-8''{content_disposition_filename}" + else: + content_disposition = f'{content_disposition_type}; filename="{self.filename}"' + self.headers.setdefault("content-disposition", content_disposition) + self.stat_result = stat_result + if stat_result is not None: + self.set_stat_headers(stat_result) + + def set_stat_headers(self, stat_result: os.stat_result) -> None: + content_length = str(stat_result.st_size) + last_modified = formatdate(stat_result.st_mtime, usegmt=True) + etag_base = str(stat_result.st_mtime) + "-" + str(stat_result.st_size) + etag = f'"{hashlib.md5(etag_base.encode(), usedforsecurity=False).hexdigest()}"' + + self.headers.setdefault("content-length", content_length) + self.headers.setdefault("last-modified", last_modified) + self.headers.setdefault("etag", etag) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + send_header_only: bool = scope["method"].upper() == "HEAD" + if self.stat_result is None: + try: + stat_result = await anyio.to_thread.run_sync(os.stat, self.path) + self.set_stat_headers(stat_result) + except FileNotFoundError: + raise RuntimeError(f"File at path {self.path} does not exist.") + else: + mode = stat_result.st_mode + if not stat.S_ISREG(mode): + raise RuntimeError(f"File at path {self.path} is not a file.") + else: + stat_result = self.stat_result + + headers = Headers(scope=scope) + http_range = headers.get("range") + http_if_range = headers.get("if-range") + + if http_range is None or (http_if_range is not None and not self._should_use_range(http_if_range)): + await self._handle_simple(send, send_header_only) + else: + try: + ranges = self._parse_range_header(http_range, stat_result.st_size) + except MalformedRangeHeader as exc: + return await PlainTextResponse(exc.content, status_code=400)(scope, receive, send) + except RangeNotSatisfiable as exc: + response = PlainTextResponse(status_code=416, headers={"Content-Range": f"*/{exc.max_size}"}) + return await response(scope, receive, send) + + if len(ranges) == 1: + start, end = ranges[0] + await self._handle_single_range(send, start, end, stat_result.st_size, send_header_only) + else: + await self._handle_multiple_ranges(send, ranges, stat_result.st_size, send_header_only) + + if self.background is not None: + await self.background() + + async def _handle_simple(self, send: Send, send_header_only: bool) -> None: + await send({"type": "http.response.start", "status": self.status_code, "headers": self.raw_headers}) + if send_header_only: + await send({"type": "http.response.body", "body": b"", "more_body": False}) + else: + async with await anyio.open_file(self.path, mode="rb") as file: + more_body = True + while more_body: + chunk = await file.read(self.chunk_size) + more_body = len(chunk) == self.chunk_size + await send({"type": "http.response.body", "body": chunk, "more_body": more_body}) + + async def _handle_single_range( + self, send: Send, start: int, end: int, file_size: int, send_header_only: bool + ) -> None: + self.headers["content-range"] = f"bytes {start}-{end - 1}/{file_size}" + self.headers["content-length"] = str(end - start) + await send({"type": "http.response.start", "status": 206, "headers": self.raw_headers}) + if send_header_only: + await send({"type": "http.response.body", "body": b"", "more_body": False}) + else: + async with await anyio.open_file(self.path, mode="rb") as file: + await file.seek(start) + more_body = True + while more_body: + chunk = await file.read(min(self.chunk_size, end - start)) + start += len(chunk) + more_body = len(chunk) == self.chunk_size and start < end + await send({"type": "http.response.body", "body": chunk, "more_body": more_body}) + + async def _handle_multiple_ranges( + self, + send: Send, + ranges: list[tuple[int, int]], + file_size: int, + send_header_only: bool, + ) -> None: + # In firefox and chrome, they use boundary with 95-96 bits entropy (that's roughly 13 bytes). + boundary = token_hex(13) + content_length, header_generator = self.generate_multipart( + ranges, boundary, file_size, self.headers["content-type"] + ) + self.headers["content-range"] = f"multipart/byteranges; boundary={boundary}" + self.headers["content-length"] = str(content_length) + await send({"type": "http.response.start", "status": 206, "headers": self.raw_headers}) + if send_header_only: + await send({"type": "http.response.body", "body": b"", "more_body": False}) + else: + async with await anyio.open_file(self.path, mode="rb") as file: + for start, end in ranges: + await send({"type": "http.response.body", "body": header_generator(start, end), "more_body": True}) + await file.seek(start) + while start < end: + chunk = await file.read(min(self.chunk_size, end - start)) + start += len(chunk) + await send({"type": "http.response.body", "body": chunk, "more_body": True}) + await send({"type": "http.response.body", "body": b"\n", "more_body": True}) + await send( + { + "type": "http.response.body", + "body": f"\n--{boundary}--\n".encode("latin-1"), + "more_body": False, + } + ) + + def _should_use_range(self, http_if_range: str) -> bool: + return http_if_range == self.headers["last-modified"] or http_if_range == self.headers["etag"] + + @staticmethod + def _parse_range_header(http_range: str, file_size: int) -> list[tuple[int, int]]: + ranges: list[tuple[int, int]] = [] + try: + units, range_ = http_range.split("=", 1) + except ValueError: + raise MalformedRangeHeader() + + units = units.strip().lower() + + if units != "bytes": + raise MalformedRangeHeader("Only support bytes range") + + ranges = [ + ( + int(_[0]) if _[0] else file_size - int(_[1]), + int(_[1]) + 1 if _[0] and _[1] and int(_[1]) < file_size else file_size, + ) + for _ in _RANGE_PATTERN.findall(range_) + if _ != ("", "") + ] + + if len(ranges) == 0: + raise MalformedRangeHeader("Range header: range must be requested") + + if any(not (0 <= start < file_size) for start, _ in ranges): + raise RangeNotSatisfiable(file_size) + + if any(start > end for start, end in ranges): + raise MalformedRangeHeader("Range header: start must be less than end") + + if len(ranges) == 1: + return ranges + + # Merge ranges + result: list[tuple[int, int]] = [] + for start, end in ranges: + for p in range(len(result)): + p_start, p_end = result[p] + if start > p_end: + continue + elif end < p_start: + result.insert(p, (start, end)) # THIS IS NOT REACHED! + break + else: + result[p] = (min(start, p_start), max(end, p_end)) + break + else: + result.append((start, end)) + + return result + + def generate_multipart( + self, + ranges: typing.Sequence[tuple[int, int]], + boundary: str, + max_size: int, + content_type: str, + ) -> tuple[int, typing.Callable[[int, int], bytes]]: + r""" + Multipart response headers generator. + + ``` + --{boundary}\n + Content-Type: {content_type}\n + Content-Range: bytes {start}-{end-1}/{max_size}\n + \n + ..........content...........\n + --{boundary}\n + Content-Type: {content_type}\n + Content-Range: bytes {start}-{end-1}/{max_size}\n + \n + ..........content...........\n + --{boundary}--\n + ``` + """ + boundary_len = len(boundary) + static_header_part_len = 44 + boundary_len + len(content_type) + len(str(max_size)) + content_length = sum( + (len(str(start)) + len(str(end - 1)) + static_header_part_len) # Headers + + (end - start) # Content + for start, end in ranges + ) + ( + 5 + boundary_len # --boundary--\n + ) + return ( + content_length, + lambda start, end: ( + f"--{boundary}\nContent-Type: {content_type}\nContent-Range: bytes {start}-{end - 1}/{max_size}\n\n" + ).encode("latin-1"), + ) diff --git a/.venv/lib/python3.12/site-packages/starlette/routing.py b/.venv/lib/python3.12/site-packages/starlette/routing.py new file mode 100644 index 00000000..add7df0c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/routing.py @@ -0,0 +1,874 @@ +from __future__ import annotations + +import contextlib +import functools +import inspect +import re +import traceback +import types +import typing +import warnings +from contextlib import asynccontextmanager +from enum import Enum + +from starlette._exception_handler import wrap_app_handling_exceptions +from starlette._utils import get_route_path, is_async_callable +from starlette.concurrency import run_in_threadpool +from starlette.convertors import CONVERTOR_TYPES, Convertor +from starlette.datastructures import URL, Headers, URLPath +from starlette.exceptions import HTTPException +from starlette.middleware import Middleware +from starlette.requests import Request +from starlette.responses import PlainTextResponse, RedirectResponse, Response +from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send +from starlette.websockets import WebSocket, WebSocketClose + + +class NoMatchFound(Exception): + """ + Raised by `.url_for(name, **path_params)` and `.url_path_for(name, **path_params)` + if no matching route exists. + """ + + def __init__(self, name: str, path_params: dict[str, typing.Any]) -> None: + params = ", ".join(list(path_params.keys())) + super().__init__(f'No route exists for name "{name}" and params "{params}".') + + +class Match(Enum): + NONE = 0 + PARTIAL = 1 + FULL = 2 + + +def iscoroutinefunction_or_partial(obj: typing.Any) -> bool: # pragma: no cover + """ + Correctly determines if an object is a coroutine function, + including those wrapped in functools.partial objects. + """ + warnings.warn( + "iscoroutinefunction_or_partial is deprecated, and will be removed in a future release.", + DeprecationWarning, + ) + while isinstance(obj, functools.partial): + obj = obj.func + return inspect.iscoroutinefunction(obj) + + +def request_response( + func: typing.Callable[[Request], typing.Awaitable[Response] | Response], +) -> ASGIApp: + """ + Takes a function or coroutine `func(request) -> response`, + and returns an ASGI application. + """ + f: typing.Callable[[Request], typing.Awaitable[Response]] = ( + func if is_async_callable(func) else functools.partial(run_in_threadpool, func) # type:ignore + ) + + async def app(scope: Scope, receive: Receive, send: Send) -> None: + request = Request(scope, receive, send) + + async def app(scope: Scope, receive: Receive, send: Send) -> None: + response = await f(request) + await response(scope, receive, send) + + await wrap_app_handling_exceptions(app, request)(scope, receive, send) + + return app + + +def websocket_session( + func: typing.Callable[[WebSocket], typing.Awaitable[None]], +) -> ASGIApp: + """ + Takes a coroutine `func(session)`, and returns an ASGI application. + """ + # assert asyncio.iscoroutinefunction(func), "WebSocket endpoints must be async" + + async def app(scope: Scope, receive: Receive, send: Send) -> None: + session = WebSocket(scope, receive=receive, send=send) + + async def app(scope: Scope, receive: Receive, send: Send) -> None: + await func(session) + + await wrap_app_handling_exceptions(app, session)(scope, receive, send) + + return app + + +def get_name(endpoint: typing.Callable[..., typing.Any]) -> str: + return getattr(endpoint, "__name__", endpoint.__class__.__name__) + + +def replace_params( + path: str, + param_convertors: dict[str, Convertor[typing.Any]], + path_params: dict[str, str], +) -> tuple[str, dict[str, str]]: + for key, value in list(path_params.items()): + if "{" + key + "}" in path: + convertor = param_convertors[key] + value = convertor.to_string(value) + path = path.replace("{" + key + "}", value) + path_params.pop(key) + return path, path_params + + +# Match parameters in URL paths, eg. '{param}', and '{param:int}' +PARAM_REGEX = re.compile("{([a-zA-Z_][a-zA-Z0-9_]*)(:[a-zA-Z_][a-zA-Z0-9_]*)?}") + + +def compile_path( + path: str, +) -> tuple[typing.Pattern[str], str, dict[str, Convertor[typing.Any]]]: + """ + Given a path string, like: "/{username:str}", + or a host string, like: "{subdomain}.mydomain.org", return a three-tuple + of (regex, format, {param_name:convertor}). + + regex: "/(?P<username>[^/]+)" + format: "/{username}" + convertors: {"username": StringConvertor()} + """ + is_host = not path.startswith("/") + + path_regex = "^" + path_format = "" + duplicated_params = set() + + idx = 0 + param_convertors = {} + for match in PARAM_REGEX.finditer(path): + param_name, convertor_type = match.groups("str") + convertor_type = convertor_type.lstrip(":") + assert convertor_type in CONVERTOR_TYPES, f"Unknown path convertor '{convertor_type}'" + convertor = CONVERTOR_TYPES[convertor_type] + + path_regex += re.escape(path[idx : match.start()]) + path_regex += f"(?P<{param_name}>{convertor.regex})" + + path_format += path[idx : match.start()] + path_format += "{%s}" % param_name + + if param_name in param_convertors: + duplicated_params.add(param_name) + + param_convertors[param_name] = convertor + + idx = match.end() + + if duplicated_params: + names = ", ".join(sorted(duplicated_params)) + ending = "s" if len(duplicated_params) > 1 else "" + raise ValueError(f"Duplicated param name{ending} {names} at path {path}") + + if is_host: + # Align with `Host.matches()` behavior, which ignores port. + hostname = path[idx:].split(":")[0] + path_regex += re.escape(hostname) + "$" + else: + path_regex += re.escape(path[idx:]) + "$" + + path_format += path[idx:] + + return re.compile(path_regex), path_format, param_convertors + + +class BaseRoute: + def matches(self, scope: Scope) -> tuple[Match, Scope]: + raise NotImplementedError() # pragma: no cover + + def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath: + raise NotImplementedError() # pragma: no cover + + async def handle(self, scope: Scope, receive: Receive, send: Send) -> None: + raise NotImplementedError() # pragma: no cover + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """ + A route may be used in isolation as a stand-alone ASGI app. + This is a somewhat contrived case, as they'll almost always be used + within a Router, but could be useful for some tooling and minimal apps. + """ + match, child_scope = self.matches(scope) + if match == Match.NONE: + if scope["type"] == "http": + response = PlainTextResponse("Not Found", status_code=404) + await response(scope, receive, send) + elif scope["type"] == "websocket": # pragma: no branch + websocket_close = WebSocketClose() + await websocket_close(scope, receive, send) + return + + scope.update(child_scope) + await self.handle(scope, receive, send) + + +class Route(BaseRoute): + def __init__( + self, + path: str, + endpoint: typing.Callable[..., typing.Any], + *, + methods: list[str] | None = None, + name: str | None = None, + include_in_schema: bool = True, + middleware: typing.Sequence[Middleware] | None = None, + ) -> None: + assert path.startswith("/"), "Routed paths must start with '/'" + self.path = path + self.endpoint = endpoint + self.name = get_name(endpoint) if name is None else name + self.include_in_schema = include_in_schema + + endpoint_handler = endpoint + while isinstance(endpoint_handler, functools.partial): + endpoint_handler = endpoint_handler.func + if inspect.isfunction(endpoint_handler) or inspect.ismethod(endpoint_handler): + # Endpoint is function or method. Treat it as `func(request) -> response`. + self.app = request_response(endpoint) + if methods is None: + methods = ["GET"] + else: + # Endpoint is a class. Treat it as ASGI. + self.app = endpoint + + if middleware is not None: + for cls, args, kwargs in reversed(middleware): + self.app = cls(self.app, *args, **kwargs) + + if methods is None: + self.methods = None + else: + self.methods = {method.upper() for method in methods} + if "GET" in self.methods: + self.methods.add("HEAD") + + self.path_regex, self.path_format, self.param_convertors = compile_path(path) + + def matches(self, scope: Scope) -> tuple[Match, Scope]: + path_params: dict[str, typing.Any] + if scope["type"] == "http": + route_path = get_route_path(scope) + match = self.path_regex.match(route_path) + if match: + matched_params = match.groupdict() + for key, value in matched_params.items(): + matched_params[key] = self.param_convertors[key].convert(value) + path_params = dict(scope.get("path_params", {})) + path_params.update(matched_params) + child_scope = {"endpoint": self.endpoint, "path_params": path_params} + if self.methods and scope["method"] not in self.methods: + return Match.PARTIAL, child_scope + else: + return Match.FULL, child_scope + return Match.NONE, {} + + def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath: + seen_params = set(path_params.keys()) + expected_params = set(self.param_convertors.keys()) + + if name != self.name or seen_params != expected_params: + raise NoMatchFound(name, path_params) + + path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params) + assert not remaining_params + return URLPath(path=path, protocol="http") + + async def handle(self, scope: Scope, receive: Receive, send: Send) -> None: + if self.methods and scope["method"] not in self.methods: + headers = {"Allow": ", ".join(self.methods)} + if "app" in scope: + raise HTTPException(status_code=405, headers=headers) + else: + response = PlainTextResponse("Method Not Allowed", status_code=405, headers=headers) + await response(scope, receive, send) + else: + await self.app(scope, receive, send) + + def __eq__(self, other: typing.Any) -> bool: + return ( + isinstance(other, Route) + and self.path == other.path + and self.endpoint == other.endpoint + and self.methods == other.methods + ) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + methods = sorted(self.methods or []) + path, name = self.path, self.name + return f"{class_name}(path={path!r}, name={name!r}, methods={methods!r})" + + +class WebSocketRoute(BaseRoute): + def __init__( + self, + path: str, + endpoint: typing.Callable[..., typing.Any], + *, + name: str | None = None, + middleware: typing.Sequence[Middleware] | None = None, + ) -> None: + assert path.startswith("/"), "Routed paths must start with '/'" + self.path = path + self.endpoint = endpoint + self.name = get_name(endpoint) if name is None else name + + endpoint_handler = endpoint + while isinstance(endpoint_handler, functools.partial): + endpoint_handler = endpoint_handler.func + if inspect.isfunction(endpoint_handler) or inspect.ismethod(endpoint_handler): + # Endpoint is function or method. Treat it as `func(websocket)`. + self.app = websocket_session(endpoint) + else: + # Endpoint is a class. Treat it as ASGI. + self.app = endpoint + + if middleware is not None: + for cls, args, kwargs in reversed(middleware): + self.app = cls(self.app, *args, **kwargs) + + self.path_regex, self.path_format, self.param_convertors = compile_path(path) + + def matches(self, scope: Scope) -> tuple[Match, Scope]: + path_params: dict[str, typing.Any] + if scope["type"] == "websocket": + route_path = get_route_path(scope) + match = self.path_regex.match(route_path) + if match: + matched_params = match.groupdict() + for key, value in matched_params.items(): + matched_params[key] = self.param_convertors[key].convert(value) + path_params = dict(scope.get("path_params", {})) + path_params.update(matched_params) + child_scope = {"endpoint": self.endpoint, "path_params": path_params} + return Match.FULL, child_scope + return Match.NONE, {} + + def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath: + seen_params = set(path_params.keys()) + expected_params = set(self.param_convertors.keys()) + + if name != self.name or seen_params != expected_params: + raise NoMatchFound(name, path_params) + + path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params) + assert not remaining_params + return URLPath(path=path, protocol="websocket") + + async def handle(self, scope: Scope, receive: Receive, send: Send) -> None: + await self.app(scope, receive, send) + + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, WebSocketRoute) and self.path == other.path and self.endpoint == other.endpoint + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(path={self.path!r}, name={self.name!r})" + + +class Mount(BaseRoute): + def __init__( + self, + path: str, + app: ASGIApp | None = None, + routes: typing.Sequence[BaseRoute] | None = None, + name: str | None = None, + *, + middleware: typing.Sequence[Middleware] | None = None, + ) -> None: + assert path == "" or path.startswith("/"), "Routed paths must start with '/'" + assert app is not None or routes is not None, "Either 'app=...', or 'routes=' must be specified" + self.path = path.rstrip("/") + if app is not None: + self._base_app: ASGIApp = app + else: + self._base_app = Router(routes=routes) + self.app = self._base_app + if middleware is not None: + for cls, args, kwargs in reversed(middleware): + self.app = cls(self.app, *args, **kwargs) + self.name = name + self.path_regex, self.path_format, self.param_convertors = compile_path(self.path + "/{path:path}") + + @property + def routes(self) -> list[BaseRoute]: + return getattr(self._base_app, "routes", []) + + def matches(self, scope: Scope) -> tuple[Match, Scope]: + path_params: dict[str, typing.Any] + if scope["type"] in ("http", "websocket"): # pragma: no branch + root_path = scope.get("root_path", "") + route_path = get_route_path(scope) + match = self.path_regex.match(route_path) + if match: + matched_params = match.groupdict() + for key, value in matched_params.items(): + matched_params[key] = self.param_convertors[key].convert(value) + remaining_path = "/" + matched_params.pop("path") + matched_path = route_path[: -len(remaining_path)] + path_params = dict(scope.get("path_params", {})) + path_params.update(matched_params) + child_scope = { + "path_params": path_params, + # app_root_path will only be set at the top level scope, + # initialized with the (optional) value of a root_path + # set above/before Starlette. And even though any + # mount will have its own child scope with its own respective + # root_path, the app_root_path will always be available in all + # the child scopes with the same top level value because it's + # set only once here with a default, any other child scope will + # just inherit that app_root_path default value stored in the + # scope. All this is needed to support Request.url_for(), as it + # uses the app_root_path to build the URL path. + "app_root_path": scope.get("app_root_path", root_path), + "root_path": root_path + matched_path, + "endpoint": self.app, + } + return Match.FULL, child_scope + return Match.NONE, {} + + def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath: + if self.name is not None and name == self.name and "path" in path_params: + # 'name' matches "<mount_name>". + path_params["path"] = path_params["path"].lstrip("/") + path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params) + if not remaining_params: + return URLPath(path=path) + elif self.name is None or name.startswith(self.name + ":"): + if self.name is None: + # No mount name. + remaining_name = name + else: + # 'name' matches "<mount_name>:<child_name>". + remaining_name = name[len(self.name) + 1 :] + path_kwarg = path_params.get("path") + path_params["path"] = "" + path_prefix, remaining_params = replace_params(self.path_format, self.param_convertors, path_params) + if path_kwarg is not None: + remaining_params["path"] = path_kwarg + for route in self.routes or []: + try: + url = route.url_path_for(remaining_name, **remaining_params) + return URLPath(path=path_prefix.rstrip("/") + str(url), protocol=url.protocol) + except NoMatchFound: + pass + raise NoMatchFound(name, path_params) + + async def handle(self, scope: Scope, receive: Receive, send: Send) -> None: + await self.app(scope, receive, send) + + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, Mount) and self.path == other.path and self.app == other.app + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + name = self.name or "" + return f"{class_name}(path={self.path!r}, name={name!r}, app={self.app!r})" + + +class Host(BaseRoute): + def __init__(self, host: str, app: ASGIApp, name: str | None = None) -> None: + assert not host.startswith("/"), "Host must not start with '/'" + self.host = host + self.app = app + self.name = name + self.host_regex, self.host_format, self.param_convertors = compile_path(host) + + @property + def routes(self) -> list[BaseRoute]: + return getattr(self.app, "routes", []) + + def matches(self, scope: Scope) -> tuple[Match, Scope]: + if scope["type"] in ("http", "websocket"): # pragma:no branch + headers = Headers(scope=scope) + host = headers.get("host", "").split(":")[0] + match = self.host_regex.match(host) + if match: + matched_params = match.groupdict() + for key, value in matched_params.items(): + matched_params[key] = self.param_convertors[key].convert(value) + path_params = dict(scope.get("path_params", {})) + path_params.update(matched_params) + child_scope = {"path_params": path_params, "endpoint": self.app} + return Match.FULL, child_scope + return Match.NONE, {} + + def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath: + if self.name is not None and name == self.name and "path" in path_params: + # 'name' matches "<mount_name>". + path = path_params.pop("path") + host, remaining_params = replace_params(self.host_format, self.param_convertors, path_params) + if not remaining_params: + return URLPath(path=path, host=host) + elif self.name is None or name.startswith(self.name + ":"): + if self.name is None: + # No mount name. + remaining_name = name + else: + # 'name' matches "<mount_name>:<child_name>". + remaining_name = name[len(self.name) + 1 :] + host, remaining_params = replace_params(self.host_format, self.param_convertors, path_params) + for route in self.routes or []: + try: + url = route.url_path_for(remaining_name, **remaining_params) + return URLPath(path=str(url), protocol=url.protocol, host=host) + except NoMatchFound: + pass + raise NoMatchFound(name, path_params) + + async def handle(self, scope: Scope, receive: Receive, send: Send) -> None: + await self.app(scope, receive, send) + + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, Host) and self.host == other.host and self.app == other.app + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + name = self.name or "" + return f"{class_name}(host={self.host!r}, name={name!r}, app={self.app!r})" + + +_T = typing.TypeVar("_T") + + +class _AsyncLiftContextManager(typing.AsyncContextManager[_T]): + def __init__(self, cm: typing.ContextManager[_T]): + self._cm = cm + + async def __aenter__(self) -> _T: + return self._cm.__enter__() + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: types.TracebackType | None, + ) -> bool | None: + return self._cm.__exit__(exc_type, exc_value, traceback) + + +def _wrap_gen_lifespan_context( + lifespan_context: typing.Callable[[typing.Any], typing.Generator[typing.Any, typing.Any, typing.Any]], +) -> typing.Callable[[typing.Any], typing.AsyncContextManager[typing.Any]]: + cmgr = contextlib.contextmanager(lifespan_context) + + @functools.wraps(cmgr) + def wrapper(app: typing.Any) -> _AsyncLiftContextManager[typing.Any]: + return _AsyncLiftContextManager(cmgr(app)) + + return wrapper + + +class _DefaultLifespan: + def __init__(self, router: Router): + self._router = router + + async def __aenter__(self) -> None: + await self._router.startup() + + async def __aexit__(self, *exc_info: object) -> None: + await self._router.shutdown() + + def __call__(self: _T, app: object) -> _T: + return self + + +class Router: + def __init__( + self, + routes: typing.Sequence[BaseRoute] | None = None, + redirect_slashes: bool = True, + default: ASGIApp | None = None, + on_startup: typing.Sequence[typing.Callable[[], typing.Any]] | None = None, + on_shutdown: typing.Sequence[typing.Callable[[], typing.Any]] | None = None, + # the generic to Lifespan[AppType] is the type of the top level application + # which the router cannot know statically, so we use typing.Any + lifespan: Lifespan[typing.Any] | None = None, + *, + middleware: typing.Sequence[Middleware] | None = None, + ) -> None: + self.routes = [] if routes is None else list(routes) + self.redirect_slashes = redirect_slashes + self.default = self.not_found if default is None else default + self.on_startup = [] if on_startup is None else list(on_startup) + self.on_shutdown = [] if on_shutdown is None else list(on_shutdown) + + if on_startup or on_shutdown: + warnings.warn( + "The on_startup and on_shutdown parameters are deprecated, and they " + "will be removed on version 1.0. Use the lifespan parameter instead. " + "See more about it on https://www.starlette.io/lifespan/.", + DeprecationWarning, + ) + if lifespan: + warnings.warn( + "The `lifespan` parameter cannot be used with `on_startup` or " + "`on_shutdown`. Both `on_startup` and `on_shutdown` will be " + "ignored." + ) + + if lifespan is None: + self.lifespan_context: Lifespan[typing.Any] = _DefaultLifespan(self) + + elif inspect.isasyncgenfunction(lifespan): + warnings.warn( + "async generator function lifespans are deprecated, " + "use an @contextlib.asynccontextmanager function instead", + DeprecationWarning, + ) + self.lifespan_context = asynccontextmanager( + lifespan, + ) + elif inspect.isgeneratorfunction(lifespan): + warnings.warn( + "generator function lifespans are deprecated, use an @contextlib.asynccontextmanager function instead", + DeprecationWarning, + ) + self.lifespan_context = _wrap_gen_lifespan_context( + lifespan, + ) + else: + self.lifespan_context = lifespan + + self.middleware_stack = self.app + if middleware: + for cls, args, kwargs in reversed(middleware): + self.middleware_stack = cls(self.middleware_stack, *args, **kwargs) + + async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] == "websocket": + websocket_close = WebSocketClose() + await websocket_close(scope, receive, send) + return + + # If we're running inside a starlette application then raise an + # exception, so that the configurable exception handler can deal with + # returning the response. For plain ASGI apps, just return the response. + if "app" in scope: + raise HTTPException(status_code=404) + else: + response = PlainTextResponse("Not Found", status_code=404) + await response(scope, receive, send) + + def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath: + for route in self.routes: + try: + return route.url_path_for(name, **path_params) + except NoMatchFound: + pass + raise NoMatchFound(name, path_params) + + async def startup(self) -> None: + """ + Run any `.on_startup` event handlers. + """ + for handler in self.on_startup: + if is_async_callable(handler): + await handler() + else: + handler() + + async def shutdown(self) -> None: + """ + Run any `.on_shutdown` event handlers. + """ + for handler in self.on_shutdown: + if is_async_callable(handler): + await handler() + else: + handler() + + async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None: + """ + Handle ASGI lifespan messages, which allows us to manage application + startup and shutdown events. + """ + started = False + app: typing.Any = scope.get("app") + await receive() + try: + async with self.lifespan_context(app) as maybe_state: + if maybe_state is not None: + if "state" not in scope: + raise RuntimeError('The server does not support "state" in the lifespan scope.') + scope["state"].update(maybe_state) + await send({"type": "lifespan.startup.complete"}) + started = True + await receive() + except BaseException: + exc_text = traceback.format_exc() + if started: + await send({"type": "lifespan.shutdown.failed", "message": exc_text}) + else: + await send({"type": "lifespan.startup.failed", "message": exc_text}) + raise + else: + await send({"type": "lifespan.shutdown.complete"}) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """ + The main entry point to the Router class. + """ + await self.middleware_stack(scope, receive, send) + + async def app(self, scope: Scope, receive: Receive, send: Send) -> None: + assert scope["type"] in ("http", "websocket", "lifespan") + + if "router" not in scope: + scope["router"] = self + + if scope["type"] == "lifespan": + await self.lifespan(scope, receive, send) + return + + partial = None + + for route in self.routes: + # Determine if any route matches the incoming scope, + # and hand over to the matching route if found. + match, child_scope = route.matches(scope) + if match == Match.FULL: + scope.update(child_scope) + await route.handle(scope, receive, send) + return + elif match == Match.PARTIAL and partial is None: + partial = route + partial_scope = child_scope + + if partial is not None: + # Â Handle partial matches. These are cases where an endpoint is + # able to handle the request, but is not a preferred option. + # We use this in particular to deal with "405 Method Not Allowed". + scope.update(partial_scope) + await partial.handle(scope, receive, send) + return + + route_path = get_route_path(scope) + if scope["type"] == "http" and self.redirect_slashes and route_path != "/": + redirect_scope = dict(scope) + if route_path.endswith("/"): + redirect_scope["path"] = redirect_scope["path"].rstrip("/") + else: + redirect_scope["path"] = redirect_scope["path"] + "/" + + for route in self.routes: + match, child_scope = route.matches(redirect_scope) + if match != Match.NONE: + redirect_url = URL(scope=redirect_scope) + response = RedirectResponse(url=str(redirect_url)) + await response(scope, receive, send) + return + + await self.default(scope, receive, send) + + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, Router) and self.routes == other.routes + + def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None: # pragma: no cover + route = Mount(path, app=app, name=name) + self.routes.append(route) + + def host(self, host: str, app: ASGIApp, name: str | None = None) -> None: # pragma: no cover + route = Host(host, app=app, name=name) + self.routes.append(route) + + def add_route( + self, + path: str, + endpoint: typing.Callable[[Request], typing.Awaitable[Response] | Response], + methods: list[str] | None = None, + name: str | None = None, + include_in_schema: bool = True, + ) -> None: # pragma: no cover + route = Route( + path, + endpoint=endpoint, + methods=methods, + name=name, + include_in_schema=include_in_schema, + ) + self.routes.append(route) + + def add_websocket_route( + self, + path: str, + endpoint: typing.Callable[[WebSocket], typing.Awaitable[None]], + name: str | None = None, + ) -> None: # pragma: no cover + route = WebSocketRoute(path, endpoint=endpoint, name=name) + self.routes.append(route) + + def route( + self, + path: str, + methods: list[str] | None = None, + name: str | None = None, + include_in_schema: bool = True, + ) -> typing.Callable: # type: ignore[type-arg] + """ + We no longer document this decorator style API, and its usage is discouraged. + Instead you should use the following approach: + + >>> routes = [Route(path, endpoint=...), ...] + >>> app = Starlette(routes=routes) + """ + warnings.warn( + "The `route` decorator is deprecated, and will be removed in version 1.0.0." + "Refer to https://www.starlette.io/routing/#http-routing for the recommended approach.", + DeprecationWarning, + ) + + def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] + self.add_route( + path, + func, + methods=methods, + name=name, + include_in_schema=include_in_schema, + ) + return func + + return decorator + + def websocket_route(self, path: str, name: str | None = None) -> typing.Callable: # type: ignore[type-arg] + """ + We no longer document this decorator style API, and its usage is discouraged. + Instead you should use the following approach: + + >>> routes = [WebSocketRoute(path, endpoint=...), ...] + >>> app = Starlette(routes=routes) + """ + warnings.warn( + "The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. Refer to " + "https://www.starlette.io/routing/#websocket-routing for the recommended approach.", + DeprecationWarning, + ) + + def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] + self.add_websocket_route(path, func, name=name) + return func + + return decorator + + def add_event_handler(self, event_type: str, func: typing.Callable[[], typing.Any]) -> None: # pragma: no cover + assert event_type in ("startup", "shutdown") + + if event_type == "startup": + self.on_startup.append(func) + else: + self.on_shutdown.append(func) + + def on_event(self, event_type: str) -> typing.Callable: # type: ignore[type-arg] + warnings.warn( + "The `on_event` decorator is deprecated, and will be removed in version 1.0.0. " + "Refer to https://www.starlette.io/lifespan/ for recommended approach.", + DeprecationWarning, + ) + + def decorator(func: typing.Callable) -> typing.Callable: # type: ignore[type-arg] + self.add_event_handler(event_type, func) + return func + + return decorator diff --git a/.venv/lib/python3.12/site-packages/starlette/schemas.py b/.venv/lib/python3.12/site-packages/starlette/schemas.py new file mode 100644 index 00000000..bfc40e2a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/schemas.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +import inspect +import re +import typing + +from starlette.requests import Request +from starlette.responses import Response +from starlette.routing import BaseRoute, Host, Mount, Route + +try: + import yaml +except ModuleNotFoundError: # pragma: no cover + yaml = None # type: ignore[assignment] + + +class OpenAPIResponse(Response): + media_type = "application/vnd.oai.openapi" + + def render(self, content: typing.Any) -> bytes: + assert yaml is not None, "`pyyaml` must be installed to use OpenAPIResponse." + assert isinstance(content, dict), "The schema passed to OpenAPIResponse should be a dictionary." + return yaml.dump(content, default_flow_style=False).encode("utf-8") + + +class EndpointInfo(typing.NamedTuple): + path: str + http_method: str + func: typing.Callable[..., typing.Any] + + +_remove_converter_pattern = re.compile(r":\w+}") + + +class BaseSchemaGenerator: + def get_schema(self, routes: list[BaseRoute]) -> dict[str, typing.Any]: + raise NotImplementedError() # pragma: no cover + + def get_endpoints(self, routes: list[BaseRoute]) -> list[EndpointInfo]: + """ + Given the routes, yields the following information: + + - path + eg: /users/ + - http_method + one of 'get', 'post', 'put', 'patch', 'delete', 'options' + - func + method ready to extract the docstring + """ + endpoints_info: list[EndpointInfo] = [] + + for route in routes: + if isinstance(route, (Mount, Host)): + routes = route.routes or [] + if isinstance(route, Mount): + path = self._remove_converter(route.path) + else: + path = "" + sub_endpoints = [ + EndpointInfo( + path="".join((path, sub_endpoint.path)), + http_method=sub_endpoint.http_method, + func=sub_endpoint.func, + ) + for sub_endpoint in self.get_endpoints(routes) + ] + endpoints_info.extend(sub_endpoints) + + elif not isinstance(route, Route) or not route.include_in_schema: + continue + + elif inspect.isfunction(route.endpoint) or inspect.ismethod(route.endpoint): + path = self._remove_converter(route.path) + for method in route.methods or ["GET"]: + if method == "HEAD": + continue + endpoints_info.append(EndpointInfo(path, method.lower(), route.endpoint)) + else: + path = self._remove_converter(route.path) + for method in ["get", "post", "put", "patch", "delete", "options"]: + if not hasattr(route.endpoint, method): + continue + func = getattr(route.endpoint, method) + endpoints_info.append(EndpointInfo(path, method.lower(), func)) + + return endpoints_info + + def _remove_converter(self, path: str) -> str: + """ + Remove the converter from the path. + For example, a route like this: + Route("/users/{id:int}", endpoint=get_user, methods=["GET"]) + Should be represented as `/users/{id}` in the OpenAPI schema. + """ + return _remove_converter_pattern.sub("}", path) + + def parse_docstring(self, func_or_method: typing.Callable[..., typing.Any]) -> dict[str, typing.Any]: + """ + Given a function, parse the docstring as YAML and return a dictionary of info. + """ + docstring = func_or_method.__doc__ + if not docstring: + return {} + + assert yaml is not None, "`pyyaml` must be installed to use parse_docstring." + + # We support having regular docstrings before the schema + # definition. Here we return just the schema part from + # the docstring. + docstring = docstring.split("---")[-1] + + parsed = yaml.safe_load(docstring) + + if not isinstance(parsed, dict): + # A regular docstring (not yaml formatted) can return + # a simple string here, which wouldn't follow the schema. + return {} + + return parsed + + def OpenAPIResponse(self, request: Request) -> Response: + routes = request.app.routes + schema = self.get_schema(routes=routes) + return OpenAPIResponse(schema) + + +class SchemaGenerator(BaseSchemaGenerator): + def __init__(self, base_schema: dict[str, typing.Any]) -> None: + self.base_schema = base_schema + + def get_schema(self, routes: list[BaseRoute]) -> dict[str, typing.Any]: + schema = dict(self.base_schema) + schema.setdefault("paths", {}) + endpoints_info = self.get_endpoints(routes) + + for endpoint in endpoints_info: + parsed = self.parse_docstring(endpoint.func) + + if not parsed: + continue + + if endpoint.path not in schema["paths"]: + schema["paths"][endpoint.path] = {} + + schema["paths"][endpoint.path][endpoint.http_method] = parsed + + return schema diff --git a/.venv/lib/python3.12/site-packages/starlette/staticfiles.py b/.venv/lib/python3.12/site-packages/starlette/staticfiles.py new file mode 100644 index 00000000..637da648 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/staticfiles.py @@ -0,0 +1,220 @@ +from __future__ import annotations + +import errno +import importlib.util +import os +import stat +import typing +from email.utils import parsedate + +import anyio +import anyio.to_thread + +from starlette._utils import get_route_path +from starlette.datastructures import URL, Headers +from starlette.exceptions import HTTPException +from starlette.responses import FileResponse, RedirectResponse, Response +from starlette.types import Receive, Scope, Send + +PathLike = typing.Union[str, "os.PathLike[str]"] + + +class NotModifiedResponse(Response): + NOT_MODIFIED_HEADERS = ( + "cache-control", + "content-location", + "date", + "etag", + "expires", + "vary", + ) + + def __init__(self, headers: Headers): + super().__init__( + status_code=304, + headers={name: value for name, value in headers.items() if name in self.NOT_MODIFIED_HEADERS}, + ) + + +class StaticFiles: + def __init__( + self, + *, + directory: PathLike | None = None, + packages: list[str | tuple[str, str]] | None = None, + html: bool = False, + check_dir: bool = True, + follow_symlink: bool = False, + ) -> None: + self.directory = directory + self.packages = packages + self.all_directories = self.get_directories(directory, packages) + self.html = html + self.config_checked = False + self.follow_symlink = follow_symlink + if check_dir and directory is not None and not os.path.isdir(directory): + raise RuntimeError(f"Directory '{directory}' does not exist") + + def get_directories( + self, + directory: PathLike | None = None, + packages: list[str | tuple[str, str]] | None = None, + ) -> list[PathLike]: + """ + Given `directory` and `packages` arguments, return a list of all the + directories that should be used for serving static files from. + """ + directories = [] + if directory is not None: + directories.append(directory) + + for package in packages or []: + if isinstance(package, tuple): + package, statics_dir = package + else: + statics_dir = "statics" + spec = importlib.util.find_spec(package) + assert spec is not None, f"Package {package!r} could not be found." + assert spec.origin is not None, f"Package {package!r} could not be found." + package_directory = os.path.normpath(os.path.join(spec.origin, "..", statics_dir)) + assert os.path.isdir(package_directory), ( + f"Directory '{statics_dir!r}' in package {package!r} could not be found." + ) + directories.append(package_directory) + + return directories + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """ + The ASGI entry point. + """ + assert scope["type"] == "http" + + if not self.config_checked: + await self.check_config() + self.config_checked = True + + path = self.get_path(scope) + response = await self.get_response(path, scope) + await response(scope, receive, send) + + def get_path(self, scope: Scope) -> str: + """ + Given the ASGI scope, return the `path` string to serve up, + with OS specific path separators, and any '..', '.' components removed. + """ + route_path = get_route_path(scope) + return os.path.normpath(os.path.join(*route_path.split("/"))) + + async def get_response(self, path: str, scope: Scope) -> Response: + """ + Returns an HTTP response, given the incoming path, method and request headers. + """ + if scope["method"] not in ("GET", "HEAD"): + raise HTTPException(status_code=405) + + try: + full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, path) + except PermissionError: + raise HTTPException(status_code=401) + except OSError as exc: + # Filename is too long, so it can't be a valid static file. + if exc.errno == errno.ENAMETOOLONG: + raise HTTPException(status_code=404) + + raise exc + + if stat_result and stat.S_ISREG(stat_result.st_mode): + # We have a static file to serve. + return self.file_response(full_path, stat_result, scope) + + elif stat_result and stat.S_ISDIR(stat_result.st_mode) and self.html: + # We're in HTML mode, and have got a directory URL. + # Check if we have 'index.html' file to serve. + index_path = os.path.join(path, "index.html") + full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, index_path) + if stat_result is not None and stat.S_ISREG(stat_result.st_mode): + if not scope["path"].endswith("/"): + # Directory URLs should redirect to always end in "/". + url = URL(scope=scope) + url = url.replace(path=url.path + "/") + return RedirectResponse(url=url) + return self.file_response(full_path, stat_result, scope) + + if self.html: + # Check for '404.html' if we're in HTML mode. + full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, "404.html") + if stat_result and stat.S_ISREG(stat_result.st_mode): + return FileResponse(full_path, stat_result=stat_result, status_code=404) + raise HTTPException(status_code=404) + + def lookup_path(self, path: str) -> tuple[str, os.stat_result | None]: + for directory in self.all_directories: + joined_path = os.path.join(directory, path) + if self.follow_symlink: + full_path = os.path.abspath(joined_path) + directory = os.path.abspath(directory) + else: + full_path = os.path.realpath(joined_path) + directory = os.path.realpath(directory) + if os.path.commonpath([full_path, directory]) != str(directory): + # Don't allow misbehaving clients to break out of the static files directory. + continue + try: + return full_path, os.stat(full_path) + except (FileNotFoundError, NotADirectoryError): + continue + return "", None + + def file_response( + self, + full_path: PathLike, + stat_result: os.stat_result, + scope: Scope, + status_code: int = 200, + ) -> Response: + request_headers = Headers(scope=scope) + + response = FileResponse(full_path, status_code=status_code, stat_result=stat_result) + if self.is_not_modified(response.headers, request_headers): + return NotModifiedResponse(response.headers) + return response + + async def check_config(self) -> None: + """ + Perform a one-off configuration check that StaticFiles is actually + pointed at a directory, so that we can raise loud errors rather than + just returning 404 responses. + """ + if self.directory is None: + return + + try: + stat_result = await anyio.to_thread.run_sync(os.stat, self.directory) + except FileNotFoundError: + raise RuntimeError(f"StaticFiles directory '{self.directory}' does not exist.") + if not (stat.S_ISDIR(stat_result.st_mode) or stat.S_ISLNK(stat_result.st_mode)): + raise RuntimeError(f"StaticFiles path '{self.directory}' is not a directory.") + + def is_not_modified(self, response_headers: Headers, request_headers: Headers) -> bool: + """ + Given the request and response headers, return `True` if an HTTP + "Not Modified" response could be returned instead. + """ + try: + if_none_match = request_headers["if-none-match"] + etag = response_headers["etag"] + if etag in [tag.strip(" W/") for tag in if_none_match.split(",")]: + return True + except KeyError: + pass + + try: + if_modified_since = parsedate(request_headers["if-modified-since"]) + last_modified = parsedate(response_headers["last-modified"]) + if if_modified_since is not None and last_modified is not None and if_modified_since >= last_modified: + return True + except KeyError: + pass + + return False diff --git a/.venv/lib/python3.12/site-packages/starlette/status.py b/.venv/lib/python3.12/site-packages/starlette/status.py new file mode 100644 index 00000000..54c1fb7d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/status.py @@ -0,0 +1,95 @@ +""" +HTTP codes +See HTTP Status Code Registry: +https://www.iana.org/assignments/http-status-codes/http-status-codes.xhtml + +And RFC 2324 - https://tools.ietf.org/html/rfc2324 +""" + +from __future__ import annotations + +HTTP_100_CONTINUE = 100 +HTTP_101_SWITCHING_PROTOCOLS = 101 +HTTP_102_PROCESSING = 102 +HTTP_103_EARLY_HINTS = 103 +HTTP_200_OK = 200 +HTTP_201_CREATED = 201 +HTTP_202_ACCEPTED = 202 +HTTP_203_NON_AUTHORITATIVE_INFORMATION = 203 +HTTP_204_NO_CONTENT = 204 +HTTP_205_RESET_CONTENT = 205 +HTTP_206_PARTIAL_CONTENT = 206 +HTTP_207_MULTI_STATUS = 207 +HTTP_208_ALREADY_REPORTED = 208 +HTTP_226_IM_USED = 226 +HTTP_300_MULTIPLE_CHOICES = 300 +HTTP_301_MOVED_PERMANENTLY = 301 +HTTP_302_FOUND = 302 +HTTP_303_SEE_OTHER = 303 +HTTP_304_NOT_MODIFIED = 304 +HTTP_305_USE_PROXY = 305 +HTTP_306_RESERVED = 306 +HTTP_307_TEMPORARY_REDIRECT = 307 +HTTP_308_PERMANENT_REDIRECT = 308 +HTTP_400_BAD_REQUEST = 400 +HTTP_401_UNAUTHORIZED = 401 +HTTP_402_PAYMENT_REQUIRED = 402 +HTTP_403_FORBIDDEN = 403 +HTTP_404_NOT_FOUND = 404 +HTTP_405_METHOD_NOT_ALLOWED = 405 +HTTP_406_NOT_ACCEPTABLE = 406 +HTTP_407_PROXY_AUTHENTICATION_REQUIRED = 407 +HTTP_408_REQUEST_TIMEOUT = 408 +HTTP_409_CONFLICT = 409 +HTTP_410_GONE = 410 +HTTP_411_LENGTH_REQUIRED = 411 +HTTP_412_PRECONDITION_FAILED = 412 +HTTP_413_REQUEST_ENTITY_TOO_LARGE = 413 +HTTP_414_REQUEST_URI_TOO_LONG = 414 +HTTP_415_UNSUPPORTED_MEDIA_TYPE = 415 +HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE = 416 +HTTP_417_EXPECTATION_FAILED = 417 +HTTP_418_IM_A_TEAPOT = 418 +HTTP_421_MISDIRECTED_REQUEST = 421 +HTTP_422_UNPROCESSABLE_ENTITY = 422 +HTTP_423_LOCKED = 423 +HTTP_424_FAILED_DEPENDENCY = 424 +HTTP_425_TOO_EARLY = 425 +HTTP_426_UPGRADE_REQUIRED = 426 +HTTP_428_PRECONDITION_REQUIRED = 428 +HTTP_429_TOO_MANY_REQUESTS = 429 +HTTP_431_REQUEST_HEADER_FIELDS_TOO_LARGE = 431 +HTTP_451_UNAVAILABLE_FOR_LEGAL_REASONS = 451 +HTTP_500_INTERNAL_SERVER_ERROR = 500 +HTTP_501_NOT_IMPLEMENTED = 501 +HTTP_502_BAD_GATEWAY = 502 +HTTP_503_SERVICE_UNAVAILABLE = 503 +HTTP_504_GATEWAY_TIMEOUT = 504 +HTTP_505_HTTP_VERSION_NOT_SUPPORTED = 505 +HTTP_506_VARIANT_ALSO_NEGOTIATES = 506 +HTTP_507_INSUFFICIENT_STORAGE = 507 +HTTP_508_LOOP_DETECTED = 508 +HTTP_510_NOT_EXTENDED = 510 +HTTP_511_NETWORK_AUTHENTICATION_REQUIRED = 511 + + +""" +WebSocket codes +https://www.iana.org/assignments/websocket/websocket.xml#close-code-number +https://developer.mozilla.org/en-US/docs/Web/API/CloseEvent +""" +WS_1000_NORMAL_CLOSURE = 1000 +WS_1001_GOING_AWAY = 1001 +WS_1002_PROTOCOL_ERROR = 1002 +WS_1003_UNSUPPORTED_DATA = 1003 +WS_1005_NO_STATUS_RCVD = 1005 +WS_1006_ABNORMAL_CLOSURE = 1006 +WS_1007_INVALID_FRAME_PAYLOAD_DATA = 1007 +WS_1008_POLICY_VIOLATION = 1008 +WS_1009_MESSAGE_TOO_BIG = 1009 +WS_1010_MANDATORY_EXT = 1010 +WS_1011_INTERNAL_ERROR = 1011 +WS_1012_SERVICE_RESTART = 1012 +WS_1013_TRY_AGAIN_LATER = 1013 +WS_1014_BAD_GATEWAY = 1014 +WS_1015_TLS_HANDSHAKE = 1015 diff --git a/.venv/lib/python3.12/site-packages/starlette/templating.py b/.venv/lib/python3.12/site-packages/starlette/templating.py new file mode 100644 index 00000000..6b01aac9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/templating.py @@ -0,0 +1,216 @@ +from __future__ import annotations + +import typing +import warnings +from os import PathLike + +from starlette.background import BackgroundTask +from starlette.datastructures import URL +from starlette.requests import Request +from starlette.responses import HTMLResponse +from starlette.types import Receive, Scope, Send + +try: + import jinja2 + + # @contextfunction was renamed to @pass_context in Jinja 3.0, and was removed in 3.1 + # hence we try to get pass_context (most installs will be >=3.1) + # and fall back to contextfunction, + # adding a type ignore for mypy to let us access an attribute that may not exist + if hasattr(jinja2, "pass_context"): + pass_context = jinja2.pass_context + else: # pragma: no cover + pass_context = jinja2.contextfunction # type: ignore[attr-defined] +except ModuleNotFoundError: # pragma: no cover + jinja2 = None # type: ignore[assignment] + + +class _TemplateResponse(HTMLResponse): + def __init__( + self, + template: typing.Any, + context: dict[str, typing.Any], + status_code: int = 200, + headers: typing.Mapping[str, str] | None = None, + media_type: str | None = None, + background: BackgroundTask | None = None, + ): + self.template = template + self.context = context + content = template.render(context) + super().__init__(content, status_code, headers, media_type, background) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + request = self.context.get("request", {}) + extensions = request.get("extensions", {}) + if "http.response.debug" in extensions: # pragma: no branch + await send( + { + "type": "http.response.debug", + "info": { + "template": self.template, + "context": self.context, + }, + } + ) + await super().__call__(scope, receive, send) + + +class Jinja2Templates: + """ + templates = Jinja2Templates("templates") + + return templates.TemplateResponse("index.html", {"request": request}) + """ + + @typing.overload + def __init__( + self, + directory: str | PathLike[str] | typing.Sequence[str | PathLike[str]], + *, + context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] | None = None, + **env_options: typing.Any, + ) -> None: ... + + @typing.overload + def __init__( + self, + *, + env: jinja2.Environment, + context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] | None = None, + ) -> None: ... + + def __init__( + self, + directory: str | PathLike[str] | typing.Sequence[str | PathLike[str]] | None = None, + *, + context_processors: list[typing.Callable[[Request], dict[str, typing.Any]]] | None = None, + env: jinja2.Environment | None = None, + **env_options: typing.Any, + ) -> None: + if env_options: + warnings.warn( + "Extra environment options are deprecated. Use a preconfigured jinja2.Environment instead.", + DeprecationWarning, + ) + assert jinja2 is not None, "jinja2 must be installed to use Jinja2Templates" + assert bool(directory) ^ bool(env), "either 'directory' or 'env' arguments must be passed" + self.context_processors = context_processors or [] + if directory is not None: + self.env = self._create_env(directory, **env_options) + elif env is not None: # pragma: no branch + self.env = env + + self._setup_env_defaults(self.env) + + def _create_env( + self, + directory: str | PathLike[str] | typing.Sequence[str | PathLike[str]], + **env_options: typing.Any, + ) -> jinja2.Environment: + loader = jinja2.FileSystemLoader(directory) + env_options.setdefault("loader", loader) + env_options.setdefault("autoescape", True) + + return jinja2.Environment(**env_options) + + def _setup_env_defaults(self, env: jinja2.Environment) -> None: + @pass_context + def url_for( + context: dict[str, typing.Any], + name: str, + /, + **path_params: typing.Any, + ) -> URL: + request: Request = context["request"] + return request.url_for(name, **path_params) + + env.globals.setdefault("url_for", url_for) + + def get_template(self, name: str) -> jinja2.Template: + return self.env.get_template(name) + + @typing.overload + def TemplateResponse( + self, + request: Request, + name: str, + context: dict[str, typing.Any] | None = None, + status_code: int = 200, + headers: typing.Mapping[str, str] | None = None, + media_type: str | None = None, + background: BackgroundTask | None = None, + ) -> _TemplateResponse: ... + + @typing.overload + def TemplateResponse( + self, + name: str, + context: dict[str, typing.Any] | None = None, + status_code: int = 200, + headers: typing.Mapping[str, str] | None = None, + media_type: str | None = None, + background: BackgroundTask | None = None, + ) -> _TemplateResponse: + # Deprecated usage + ... + + def TemplateResponse(self, *args: typing.Any, **kwargs: typing.Any) -> _TemplateResponse: + if args: + if isinstance(args[0], str): # the first argument is template name (old style) + warnings.warn( + "The `name` is not the first parameter anymore. " + "The first parameter should be the `Request` instance.\n" + 'Replace `TemplateResponse(name, {"request": request})` by `TemplateResponse(request, name)`.', + DeprecationWarning, + ) + + name = args[0] + context = args[1] if len(args) > 1 else kwargs.get("context", {}) + status_code = args[2] if len(args) > 2 else kwargs.get("status_code", 200) + headers = args[2] if len(args) > 2 else kwargs.get("headers") + media_type = args[3] if len(args) > 3 else kwargs.get("media_type") + background = args[4] if len(args) > 4 else kwargs.get("background") + + if "request" not in context: + raise ValueError('context must include a "request" key') + request = context["request"] + else: # the first argument is a request instance (new style) + request = args[0] + name = args[1] if len(args) > 1 else kwargs["name"] + context = args[2] if len(args) > 2 else kwargs.get("context", {}) + status_code = args[3] if len(args) > 3 else kwargs.get("status_code", 200) + headers = args[4] if len(args) > 4 else kwargs.get("headers") + media_type = args[5] if len(args) > 5 else kwargs.get("media_type") + background = args[6] if len(args) > 6 else kwargs.get("background") + else: # all arguments are kwargs + if "request" not in kwargs: + warnings.warn( + "The `TemplateResponse` now requires the `request` argument.\n" + 'Replace `TemplateResponse(name, {"context": context})` by `TemplateResponse(request, name)`.', + DeprecationWarning, + ) + if "request" not in kwargs.get("context", {}): + raise ValueError('context must include a "request" key') + + context = kwargs.get("context", {}) + request = kwargs.get("request", context.get("request")) + name = typing.cast(str, kwargs["name"]) + status_code = kwargs.get("status_code", 200) + headers = kwargs.get("headers") + media_type = kwargs.get("media_type") + background = kwargs.get("background") + + context.setdefault("request", request) + for context_processor in self.context_processors: + context.update(context_processor(request)) + + template = self.get_template(name) + return _TemplateResponse( + template, + context, + status_code=status_code, + headers=headers, + media_type=media_type, + background=background, + ) diff --git a/.venv/lib/python3.12/site-packages/starlette/testclient.py b/.venv/lib/python3.12/site-packages/starlette/testclient.py new file mode 100644 index 00000000..d54025e5 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/testclient.py @@ -0,0 +1,731 @@ +from __future__ import annotations + +import contextlib +import inspect +import io +import json +import math +import sys +import typing +import warnings +from concurrent.futures import Future +from types import GeneratorType +from urllib.parse import unquote, urljoin + +import anyio +import anyio.abc +import anyio.from_thread +from anyio.streams.stapled import StapledObjectStream + +from starlette._utils import is_async_callable +from starlette.types import ASGIApp, Message, Receive, Scope, Send +from starlette.websockets import WebSocketDisconnect + +if sys.version_info >= (3, 10): # pragma: no cover + from typing import TypeGuard +else: # pragma: no cover + from typing_extensions import TypeGuard + +try: + import httpx +except ModuleNotFoundError: # pragma: no cover + raise RuntimeError( + "The starlette.testclient module requires the httpx package to be installed.\n" + "You can install this with:\n" + " $ pip install httpx\n" + ) +_PortalFactoryType = typing.Callable[[], typing.ContextManager[anyio.abc.BlockingPortal]] + +ASGIInstance = typing.Callable[[Receive, Send], typing.Awaitable[None]] +ASGI2App = typing.Callable[[Scope], ASGIInstance] +ASGI3App = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]] + + +_RequestData = typing.Mapping[str, typing.Union[str, typing.Iterable[str], bytes]] + + +def _is_asgi3(app: ASGI2App | ASGI3App) -> TypeGuard[ASGI3App]: + if inspect.isclass(app): + return hasattr(app, "__await__") + return is_async_callable(app) + + +class _WrapASGI2: + """ + Provide an ASGI3 interface onto an ASGI2 app. + """ + + def __init__(self, app: ASGI2App) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + instance = self.app(scope) + await instance(receive, send) + + +class _AsyncBackend(typing.TypedDict): + backend: str + backend_options: dict[str, typing.Any] + + +class _Upgrade(Exception): + def __init__(self, session: WebSocketTestSession) -> None: + self.session = session + + +class WebSocketDenialResponse( # type: ignore[misc] + httpx.Response, + WebSocketDisconnect, +): + """ + A special case of `WebSocketDisconnect`, raised in the `TestClient` if the + `WebSocket` is closed before being accepted with a `send_denial_response()`. + """ + + +class WebSocketTestSession: + def __init__( + self, + app: ASGI3App, + scope: Scope, + portal_factory: _PortalFactoryType, + ) -> None: + self.app = app + self.scope = scope + self.accepted_subprotocol = None + self.portal_factory = portal_factory + self.extra_headers = None + + def __enter__(self) -> WebSocketTestSession: + with contextlib.ExitStack() as stack: + self.portal = portal = stack.enter_context(self.portal_factory()) + fut, cs = portal.start_task(self._run) + stack.callback(fut.result) + stack.callback(portal.call, cs.cancel) + self.send({"type": "websocket.connect"}) + message = self.receive() + self._raise_on_close(message) + self.accepted_subprotocol = message.get("subprotocol", None) + self.extra_headers = message.get("headers", None) + stack.callback(self.close, 1000) + self.exit_stack = stack.pop_all() + return self + + def __exit__(self, *args: typing.Any) -> bool | None: + return self.exit_stack.__exit__(*args) + + async def _run(self, *, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None: + """ + The sub-thread in which the websocket session runs. + """ + send: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf) + send_tx, send_rx = send + receive: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf) + receive_tx, receive_rx = receive + with send_tx, send_rx, receive_tx, receive_rx, anyio.CancelScope() as cs: + self._receive_tx = receive_tx + self._send_rx = send_rx + task_status.started(cs) + await self.app(self.scope, receive_rx.receive, send_tx.send) + + # wait for cs.cancel to be called before closing streams + await anyio.sleep_forever() + + def _raise_on_close(self, message: Message) -> None: + if message["type"] == "websocket.close": + raise WebSocketDisconnect(code=message.get("code", 1000), reason=message.get("reason", "")) + elif message["type"] == "websocket.http.response.start": + status_code: int = message["status"] + headers: list[tuple[bytes, bytes]] = message["headers"] + body: list[bytes] = [] + while True: + message = self.receive() + assert message["type"] == "websocket.http.response.body" + body.append(message["body"]) + if not message.get("more_body", False): + break + raise WebSocketDenialResponse(status_code=status_code, headers=headers, content=b"".join(body)) + + def send(self, message: Message) -> None: + self.portal.call(self._receive_tx.send, message) + + def send_text(self, data: str) -> None: + self.send({"type": "websocket.receive", "text": data}) + + def send_bytes(self, data: bytes) -> None: + self.send({"type": "websocket.receive", "bytes": data}) + + def send_json(self, data: typing.Any, mode: typing.Literal["text", "binary"] = "text") -> None: + text = json.dumps(data, separators=(",", ":"), ensure_ascii=False) + if mode == "text": + self.send({"type": "websocket.receive", "text": text}) + else: + self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")}) + + def close(self, code: int = 1000, reason: str | None = None) -> None: + self.send({"type": "websocket.disconnect", "code": code, "reason": reason}) + + def receive(self) -> Message: + return self.portal.call(self._send_rx.receive) + + def receive_text(self) -> str: + message = self.receive() + self._raise_on_close(message) + return typing.cast(str, message["text"]) + + def receive_bytes(self) -> bytes: + message = self.receive() + self._raise_on_close(message) + return typing.cast(bytes, message["bytes"]) + + def receive_json(self, mode: typing.Literal["text", "binary"] = "text") -> typing.Any: + message = self.receive() + self._raise_on_close(message) + if mode == "text": + text = message["text"] + else: + text = message["bytes"].decode("utf-8") + return json.loads(text) + + +class _TestClientTransport(httpx.BaseTransport): + def __init__( + self, + app: ASGI3App, + portal_factory: _PortalFactoryType, + raise_server_exceptions: bool = True, + root_path: str = "", + *, + client: tuple[str, int], + app_state: dict[str, typing.Any], + ) -> None: + self.app = app + self.raise_server_exceptions = raise_server_exceptions + self.root_path = root_path + self.portal_factory = portal_factory + self.app_state = app_state + self.client = client + + def handle_request(self, request: httpx.Request) -> httpx.Response: + scheme = request.url.scheme + netloc = request.url.netloc.decode(encoding="ascii") + path = request.url.path + raw_path = request.url.raw_path + query = request.url.query.decode(encoding="ascii") + + default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme] + + if ":" in netloc: + host, port_string = netloc.split(":", 1) + port = int(port_string) + else: + host = netloc + port = default_port + + # Include the 'host' header. + if "host" in request.headers: + headers: list[tuple[bytes, bytes]] = [] + elif port == default_port: # pragma: no cover + headers = [(b"host", host.encode())] + else: # pragma: no cover + headers = [(b"host", (f"{host}:{port}").encode())] + + # Include other request headers. + headers += [(key.lower().encode(), value.encode()) for key, value in request.headers.multi_items()] + + scope: dict[str, typing.Any] + + if scheme in {"ws", "wss"}: + subprotocol = request.headers.get("sec-websocket-protocol", None) + if subprotocol is None: + subprotocols: typing.Sequence[str] = [] + else: + subprotocols = [value.strip() for value in subprotocol.split(",")] + scope = { + "type": "websocket", + "path": unquote(path), + "raw_path": raw_path.split(b"?", 1)[0], + "root_path": self.root_path, + "scheme": scheme, + "query_string": query.encode(), + "headers": headers, + "client": self.client, + "server": [host, port], + "subprotocols": subprotocols, + "state": self.app_state.copy(), + "extensions": {"websocket.http.response": {}}, + } + session = WebSocketTestSession(self.app, scope, self.portal_factory) + raise _Upgrade(session) + + scope = { + "type": "http", + "http_version": "1.1", + "method": request.method, + "path": unquote(path), + "raw_path": raw_path.split(b"?", 1)[0], + "root_path": self.root_path, + "scheme": scheme, + "query_string": query.encode(), + "headers": headers, + "client": self.client, + "server": [host, port], + "extensions": {"http.response.debug": {}}, + "state": self.app_state.copy(), + } + + request_complete = False + response_started = False + response_complete: anyio.Event + raw_kwargs: dict[str, typing.Any] = {"stream": io.BytesIO()} + template = None + context = None + + async def receive() -> Message: + nonlocal request_complete + + if request_complete: + if not response_complete.is_set(): + await response_complete.wait() + return {"type": "http.disconnect"} + + body = request.read() + if isinstance(body, str): + body_bytes: bytes = body.encode("utf-8") # pragma: no cover + elif body is None: + body_bytes = b"" # pragma: no cover + elif isinstance(body, GeneratorType): + try: # pragma: no cover + chunk = body.send(None) + if isinstance(chunk, str): + chunk = chunk.encode("utf-8") + return {"type": "http.request", "body": chunk, "more_body": True} + except StopIteration: # pragma: no cover + request_complete = True + return {"type": "http.request", "body": b""} + else: + body_bytes = body + + request_complete = True + return {"type": "http.request", "body": body_bytes} + + async def send(message: Message) -> None: + nonlocal raw_kwargs, response_started, template, context + + if message["type"] == "http.response.start": + assert not response_started, 'Received multiple "http.response.start" messages.' + raw_kwargs["status_code"] = message["status"] + raw_kwargs["headers"] = [(key.decode(), value.decode()) for key, value in message.get("headers", [])] + response_started = True + elif message["type"] == "http.response.body": + assert response_started, 'Received "http.response.body" without "http.response.start".' + assert not response_complete.is_set(), 'Received "http.response.body" after response completed.' + body = message.get("body", b"") + more_body = message.get("more_body", False) + if request.method != "HEAD": + raw_kwargs["stream"].write(body) + if not more_body: + raw_kwargs["stream"].seek(0) + response_complete.set() + elif message["type"] == "http.response.debug": + template = message["info"]["template"] + context = message["info"]["context"] + + try: + with self.portal_factory() as portal: + response_complete = portal.call(anyio.Event) + portal.call(self.app, scope, receive, send) + except BaseException as exc: + if self.raise_server_exceptions: + raise exc + + if self.raise_server_exceptions: + assert response_started, "TestClient did not receive any response." + elif not response_started: + raw_kwargs = { + "status_code": 500, + "headers": [], + "stream": io.BytesIO(), + } + + raw_kwargs["stream"] = httpx.ByteStream(raw_kwargs["stream"].read()) + + response = httpx.Response(**raw_kwargs, request=request) + if template is not None: + response.template = template # type: ignore[attr-defined] + response.context = context # type: ignore[attr-defined] + return response + + +class TestClient(httpx.Client): + __test__ = False + task: Future[None] + portal: anyio.abc.BlockingPortal | None = None + + def __init__( + self, + app: ASGIApp, + base_url: str = "http://testserver", + raise_server_exceptions: bool = True, + root_path: str = "", + backend: typing.Literal["asyncio", "trio"] = "asyncio", + backend_options: dict[str, typing.Any] | None = None, + cookies: httpx._types.CookieTypes | None = None, + headers: dict[str, str] | None = None, + follow_redirects: bool = True, + client: tuple[str, int] = ("testclient", 50000), + ) -> None: + self.async_backend = _AsyncBackend(backend=backend, backend_options=backend_options or {}) + if _is_asgi3(app): + asgi_app = app + else: + app = typing.cast(ASGI2App, app) # type: ignore[assignment] + asgi_app = _WrapASGI2(app) # type: ignore[arg-type] + self.app = asgi_app + self.app_state: dict[str, typing.Any] = {} + transport = _TestClientTransport( + self.app, + portal_factory=self._portal_factory, + raise_server_exceptions=raise_server_exceptions, + root_path=root_path, + app_state=self.app_state, + client=client, + ) + if headers is None: + headers = {} + headers.setdefault("user-agent", "testclient") + super().__init__( + base_url=base_url, + headers=headers, + transport=transport, + follow_redirects=follow_redirects, + cookies=cookies, + ) + + @contextlib.contextmanager + def _portal_factory(self) -> typing.Generator[anyio.abc.BlockingPortal, None, None]: + if self.portal is not None: + yield self.portal + else: + with anyio.from_thread.start_blocking_portal(**self.async_backend) as portal: + yield portal + + def request( # type: ignore[override] + self, + method: str, + url: httpx._types.URLTypes, + *, + content: httpx._types.RequestContent | None = None, + data: _RequestData | None = None, + files: httpx._types.RequestFiles | None = None, + json: typing.Any = None, + params: httpx._types.QueryParamTypes | None = None, + headers: httpx._types.HeaderTypes | None = None, + cookies: httpx._types.CookieTypes | None = None, + auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict[str, typing.Any] | None = None, + ) -> httpx.Response: + if timeout is not httpx.USE_CLIENT_DEFAULT: + warnings.warn( + "You should not use the 'timeout' argument with the TestClient. " + "See https://github.com/encode/starlette/issues/1108 for more information.", + DeprecationWarning, + ) + url = self._merge_url(url) + return super().request( + method, + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + def get( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + params: httpx._types.QueryParamTypes | None = None, + headers: httpx._types.HeaderTypes | None = None, + cookies: httpx._types.CookieTypes | None = None, + auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict[str, typing.Any] | None = None, + ) -> httpx.Response: + return super().get( + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + def options( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + params: httpx._types.QueryParamTypes | None = None, + headers: httpx._types.HeaderTypes | None = None, + cookies: httpx._types.CookieTypes | None = None, + auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict[str, typing.Any] | None = None, + ) -> httpx.Response: + return super().options( + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + def head( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + params: httpx._types.QueryParamTypes | None = None, + headers: httpx._types.HeaderTypes | None = None, + cookies: httpx._types.CookieTypes | None = None, + auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict[str, typing.Any] | None = None, + ) -> httpx.Response: + return super().head( + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + def post( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + content: httpx._types.RequestContent | None = None, + data: _RequestData | None = None, + files: httpx._types.RequestFiles | None = None, + json: typing.Any = None, + params: httpx._types.QueryParamTypes | None = None, + headers: httpx._types.HeaderTypes | None = None, + cookies: httpx._types.CookieTypes | None = None, + auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict[str, typing.Any] | None = None, + ) -> httpx.Response: + return super().post( + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + def put( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + content: httpx._types.RequestContent | None = None, + data: _RequestData | None = None, + files: httpx._types.RequestFiles | None = None, + json: typing.Any = None, + params: httpx._types.QueryParamTypes | None = None, + headers: httpx._types.HeaderTypes | None = None, + cookies: httpx._types.CookieTypes | None = None, + auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict[str, typing.Any] | None = None, + ) -> httpx.Response: + return super().put( + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + def patch( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + content: httpx._types.RequestContent | None = None, + data: _RequestData | None = None, + files: httpx._types.RequestFiles | None = None, + json: typing.Any = None, + params: httpx._types.QueryParamTypes | None = None, + headers: httpx._types.HeaderTypes | None = None, + cookies: httpx._types.CookieTypes | None = None, + auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict[str, typing.Any] | None = None, + ) -> httpx.Response: + return super().patch( + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + def delete( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + params: httpx._types.QueryParamTypes | None = None, + headers: httpx._types.HeaderTypes | None = None, + cookies: httpx._types.CookieTypes | None = None, + auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict[str, typing.Any] | None = None, + ) -> httpx.Response: + return super().delete( + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + def websocket_connect( + self, + url: str, + subprotocols: typing.Sequence[str] | None = None, + **kwargs: typing.Any, + ) -> WebSocketTestSession: + url = urljoin("ws://testserver", url) + headers = kwargs.get("headers", {}) + headers.setdefault("connection", "upgrade") + headers.setdefault("sec-websocket-key", "testserver==") + headers.setdefault("sec-websocket-version", "13") + if subprotocols is not None: + headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols)) + kwargs["headers"] = headers + try: + super().request("GET", url, **kwargs) + except _Upgrade as exc: + session = exc.session + else: + raise RuntimeError("Expected WebSocket upgrade") # pragma: no cover + + return session + + def __enter__(self) -> TestClient: + with contextlib.ExitStack() as stack: + self.portal = portal = stack.enter_context(anyio.from_thread.start_blocking_portal(**self.async_backend)) + + @stack.callback + def reset_portal() -> None: + self.portal = None + + send: anyio.create_memory_object_stream[typing.MutableMapping[str, typing.Any] | None] = ( + anyio.create_memory_object_stream(math.inf) + ) + receive: anyio.create_memory_object_stream[typing.MutableMapping[str, typing.Any]] = ( + anyio.create_memory_object_stream(math.inf) + ) + for channel in (*send, *receive): + stack.callback(channel.close) + self.stream_send = StapledObjectStream(*send) + self.stream_receive = StapledObjectStream(*receive) + self.task = portal.start_task_soon(self.lifespan) + portal.call(self.wait_startup) + + @stack.callback + def wait_shutdown() -> None: + portal.call(self.wait_shutdown) + + self.exit_stack = stack.pop_all() + + return self + + def __exit__(self, *args: typing.Any) -> None: + self.exit_stack.close() + + async def lifespan(self) -> None: + scope = {"type": "lifespan", "state": self.app_state} + try: + await self.app(scope, self.stream_receive.receive, self.stream_send.send) + finally: + await self.stream_send.send(None) + + async def wait_startup(self) -> None: + await self.stream_receive.send({"type": "lifespan.startup"}) + + async def receive() -> typing.Any: + message = await self.stream_send.receive() + if message is None: + self.task.result() + return message + + message = await receive() + assert message["type"] in ( + "lifespan.startup.complete", + "lifespan.startup.failed", + ) + if message["type"] == "lifespan.startup.failed": + await receive() + + async def wait_shutdown(self) -> None: + async def receive() -> typing.Any: + message = await self.stream_send.receive() + if message is None: + self.task.result() + return message + + await self.stream_receive.send({"type": "lifespan.shutdown"}) + message = await receive() + assert message["type"] in ( + "lifespan.shutdown.complete", + "lifespan.shutdown.failed", + ) + if message["type"] == "lifespan.shutdown.failed": + await receive() diff --git a/.venv/lib/python3.12/site-packages/starlette/types.py b/.venv/lib/python3.12/site-packages/starlette/types.py new file mode 100644 index 00000000..893f8729 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/types.py @@ -0,0 +1,24 @@ +import typing + +if typing.TYPE_CHECKING: + from starlette.requests import Request + from starlette.responses import Response + from starlette.websockets import WebSocket + +AppType = typing.TypeVar("AppType") + +Scope = typing.MutableMapping[str, typing.Any] +Message = typing.MutableMapping[str, typing.Any] + +Receive = typing.Callable[[], typing.Awaitable[Message]] +Send = typing.Callable[[Message], typing.Awaitable[None]] + +ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]] + +StatelessLifespan = typing.Callable[[AppType], typing.AsyncContextManager[None]] +StatefulLifespan = typing.Callable[[AppType], typing.AsyncContextManager[typing.Mapping[str, typing.Any]]] +Lifespan = typing.Union[StatelessLifespan[AppType], StatefulLifespan[AppType]] + +HTTPExceptionHandler = typing.Callable[["Request", Exception], "Response | typing.Awaitable[Response]"] +WebSocketExceptionHandler = typing.Callable[["WebSocket", Exception], typing.Awaitable[None]] +ExceptionHandler = typing.Union[HTTPExceptionHandler, WebSocketExceptionHandler] diff --git a/.venv/lib/python3.12/site-packages/starlette/websockets.py b/.venv/lib/python3.12/site-packages/starlette/websockets.py new file mode 100644 index 00000000..6b46f4ea --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/websockets.py @@ -0,0 +1,195 @@ +from __future__ import annotations + +import enum +import json +import typing + +from starlette.requests import HTTPConnection +from starlette.responses import Response +from starlette.types import Message, Receive, Scope, Send + + +class WebSocketState(enum.Enum): + CONNECTING = 0 + CONNECTED = 1 + DISCONNECTED = 2 + RESPONSE = 3 + + +class WebSocketDisconnect(Exception): + def __init__(self, code: int = 1000, reason: str | None = None) -> None: + self.code = code + self.reason = reason or "" + + +class WebSocket(HTTPConnection): + def __init__(self, scope: Scope, receive: Receive, send: Send) -> None: + super().__init__(scope) + assert scope["type"] == "websocket" + self._receive = receive + self._send = send + self.client_state = WebSocketState.CONNECTING + self.application_state = WebSocketState.CONNECTING + + async def receive(self) -> Message: + """ + Receive ASGI websocket messages, ensuring valid state transitions. + """ + if self.client_state == WebSocketState.CONNECTING: + message = await self._receive() + message_type = message["type"] + if message_type != "websocket.connect": + raise RuntimeError(f'Expected ASGI message "websocket.connect", but got {message_type!r}') + self.client_state = WebSocketState.CONNECTED + return message + elif self.client_state == WebSocketState.CONNECTED: + message = await self._receive() + message_type = message["type"] + if message_type not in {"websocket.receive", "websocket.disconnect"}: + raise RuntimeError( + f'Expected ASGI message "websocket.receive" or "websocket.disconnect", but got {message_type!r}' + ) + if message_type == "websocket.disconnect": + self.client_state = WebSocketState.DISCONNECTED + return message + else: + raise RuntimeError('Cannot call "receive" once a disconnect message has been received.') + + async def send(self, message: Message) -> None: + """ + Send ASGI websocket messages, ensuring valid state transitions. + """ + if self.application_state == WebSocketState.CONNECTING: + message_type = message["type"] + if message_type not in {"websocket.accept", "websocket.close", "websocket.http.response.start"}: + raise RuntimeError( + 'Expected ASGI message "websocket.accept", "websocket.close" or "websocket.http.response.start", ' + f"but got {message_type!r}" + ) + if message_type == "websocket.close": + self.application_state = WebSocketState.DISCONNECTED + elif message_type == "websocket.http.response.start": + self.application_state = WebSocketState.RESPONSE + else: + self.application_state = WebSocketState.CONNECTED + await self._send(message) + elif self.application_state == WebSocketState.CONNECTED: + message_type = message["type"] + if message_type not in {"websocket.send", "websocket.close"}: + raise RuntimeError( + f'Expected ASGI message "websocket.send" or "websocket.close", but got {message_type!r}' + ) + if message_type == "websocket.close": + self.application_state = WebSocketState.DISCONNECTED + try: + await self._send(message) + except OSError: + self.application_state = WebSocketState.DISCONNECTED + raise WebSocketDisconnect(code=1006) + elif self.application_state == WebSocketState.RESPONSE: + message_type = message["type"] + if message_type != "websocket.http.response.body": + raise RuntimeError(f'Expected ASGI message "websocket.http.response.body", but got {message_type!r}') + if not message.get("more_body", False): + self.application_state = WebSocketState.DISCONNECTED + await self._send(message) + else: + raise RuntimeError('Cannot call "send" once a close message has been sent.') + + async def accept( + self, + subprotocol: str | None = None, + headers: typing.Iterable[tuple[bytes, bytes]] | None = None, + ) -> None: + headers = headers or [] + + if self.client_state == WebSocketState.CONNECTING: # pragma: no branch + # If we haven't yet seen the 'connect' message, then wait for it first. + await self.receive() + await self.send({"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers}) + + def _raise_on_disconnect(self, message: Message) -> None: + if message["type"] == "websocket.disconnect": + raise WebSocketDisconnect(message["code"], message.get("reason")) + + async def receive_text(self) -> str: + if self.application_state != WebSocketState.CONNECTED: + raise RuntimeError('WebSocket is not connected. Need to call "accept" first.') + message = await self.receive() + self._raise_on_disconnect(message) + return typing.cast(str, message["text"]) + + async def receive_bytes(self) -> bytes: + if self.application_state != WebSocketState.CONNECTED: + raise RuntimeError('WebSocket is not connected. Need to call "accept" first.') + message = await self.receive() + self._raise_on_disconnect(message) + return typing.cast(bytes, message["bytes"]) + + async def receive_json(self, mode: str = "text") -> typing.Any: + if mode not in {"text", "binary"}: + raise RuntimeError('The "mode" argument should be "text" or "binary".') + if self.application_state != WebSocketState.CONNECTED: + raise RuntimeError('WebSocket is not connected. Need to call "accept" first.') + message = await self.receive() + self._raise_on_disconnect(message) + + if mode == "text": + text = message["text"] + else: + text = message["bytes"].decode("utf-8") + return json.loads(text) + + async def iter_text(self) -> typing.AsyncIterator[str]: + try: + while True: + yield await self.receive_text() + except WebSocketDisconnect: + pass + + async def iter_bytes(self) -> typing.AsyncIterator[bytes]: + try: + while True: + yield await self.receive_bytes() + except WebSocketDisconnect: + pass + + async def iter_json(self) -> typing.AsyncIterator[typing.Any]: + try: + while True: + yield await self.receive_json() + except WebSocketDisconnect: + pass + + async def send_text(self, data: str) -> None: + await self.send({"type": "websocket.send", "text": data}) + + async def send_bytes(self, data: bytes) -> None: + await self.send({"type": "websocket.send", "bytes": data}) + + async def send_json(self, data: typing.Any, mode: str = "text") -> None: + if mode not in {"text", "binary"}: + raise RuntimeError('The "mode" argument should be "text" or "binary".') + text = json.dumps(data, separators=(",", ":"), ensure_ascii=False) + if mode == "text": + await self.send({"type": "websocket.send", "text": text}) + else: + await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")}) + + async def close(self, code: int = 1000, reason: str | None = None) -> None: + await self.send({"type": "websocket.close", "code": code, "reason": reason or ""}) + + async def send_denial_response(self, response: Response) -> None: + if "websocket.http.response" in self.scope.get("extensions", {}): + await response(self.scope, self.receive, self.send) + else: + raise RuntimeError("The server doesn't support the Websocket Denial Response extension.") + + +class WebSocketClose: + def __init__(self, code: int = 1000, reason: str | None = None) -> None: + self.code = code + self.reason = reason or "" + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + await send({"type": "websocket.close", "code": self.code, "reason": self.reason}) |