diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/starlette/datastructures.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/starlette/datastructures.py | 674 |
1 files changed, 674 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/starlette/datastructures.py b/.venv/lib/python3.12/site-packages/starlette/datastructures.py new file mode 100644 index 00000000..f5d74d25 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/starlette/datastructures.py @@ -0,0 +1,674 @@ +from __future__ import annotations + +import typing +from shlex import shlex +from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit + +from starlette.concurrency import run_in_threadpool +from starlette.types import Scope + + +class Address(typing.NamedTuple): + host: str + port: int + + +_KeyType = typing.TypeVar("_KeyType") +# Mapping keys are invariant but their values are covariant since +# you can only read them +# that is, you can't do `Mapping[str, Animal]()["fido"] = Dog()` +_CovariantValueType = typing.TypeVar("_CovariantValueType", covariant=True) + + +class URL: + def __init__( + self, + url: str = "", + scope: Scope | None = None, + **components: typing.Any, + ) -> None: + if scope is not None: + assert not url, 'Cannot set both "url" and "scope".' + assert not components, 'Cannot set both "scope" and "**components".' + scheme = scope.get("scheme", "http") + server = scope.get("server", None) + path = scope["path"] + query_string = scope.get("query_string", b"") + + host_header = None + for key, value in scope["headers"]: + if key == b"host": + host_header = value.decode("latin-1") + break + + if host_header is not None: + url = f"{scheme}://{host_header}{path}" + elif server is None: + url = path + else: + host, port = server + default_port = {"http": 80, "https": 443, "ws": 80, "wss": 443}[scheme] + if port == default_port: + url = f"{scheme}://{host}{path}" + else: + url = f"{scheme}://{host}:{port}{path}" + + if query_string: + url += "?" + query_string.decode() + elif components: + assert not url, 'Cannot set both "url" and "**components".' + url = URL("").replace(**components).components.geturl() + + self._url = url + + @property + def components(self) -> SplitResult: + if not hasattr(self, "_components"): + self._components = urlsplit(self._url) + return self._components + + @property + def scheme(self) -> str: + return self.components.scheme + + @property + def netloc(self) -> str: + return self.components.netloc + + @property + def path(self) -> str: + return self.components.path + + @property + def query(self) -> str: + return self.components.query + + @property + def fragment(self) -> str: + return self.components.fragment + + @property + def username(self) -> None | str: + return self.components.username + + @property + def password(self) -> None | str: + return self.components.password + + @property + def hostname(self) -> None | str: + return self.components.hostname + + @property + def port(self) -> int | None: + return self.components.port + + @property + def is_secure(self) -> bool: + return self.scheme in ("https", "wss") + + def replace(self, **kwargs: typing.Any) -> URL: + if "username" in kwargs or "password" in kwargs or "hostname" in kwargs or "port" in kwargs: + hostname = kwargs.pop("hostname", None) + port = kwargs.pop("port", self.port) + username = kwargs.pop("username", self.username) + password = kwargs.pop("password", self.password) + + if hostname is None: + netloc = self.netloc + _, _, hostname = netloc.rpartition("@") + + if hostname[-1] != "]": + hostname = hostname.rsplit(":", 1)[0] + + netloc = hostname + if port is not None: + netloc += f":{port}" + if username is not None: + userpass = username + if password is not None: + userpass += f":{password}" + netloc = f"{userpass}@{netloc}" + + kwargs["netloc"] = netloc + + components = self.components._replace(**kwargs) + return self.__class__(components.geturl()) + + def include_query_params(self, **kwargs: typing.Any) -> URL: + params = MultiDict(parse_qsl(self.query, keep_blank_values=True)) + params.update({str(key): str(value) for key, value in kwargs.items()}) + query = urlencode(params.multi_items()) + return self.replace(query=query) + + def replace_query_params(self, **kwargs: typing.Any) -> URL: + query = urlencode([(str(key), str(value)) for key, value in kwargs.items()]) + return self.replace(query=query) + + def remove_query_params(self, keys: str | typing.Sequence[str]) -> URL: + if isinstance(keys, str): + keys = [keys] + params = MultiDict(parse_qsl(self.query, keep_blank_values=True)) + for key in keys: + params.pop(key, None) + query = urlencode(params.multi_items()) + return self.replace(query=query) + + def __eq__(self, other: typing.Any) -> bool: + return str(self) == str(other) + + def __str__(self) -> str: + return self._url + + def __repr__(self) -> str: + url = str(self) + if self.password: + url = str(self.replace(password="********")) + return f"{self.__class__.__name__}({repr(url)})" + + +class URLPath(str): + """ + A URL path string that may also hold an associated protocol and/or host. + Used by the routing to return `url_path_for` matches. + """ + + def __new__(cls, path: str, protocol: str = "", host: str = "") -> URLPath: + assert protocol in ("http", "websocket", "") + return str.__new__(cls, path) + + def __init__(self, path: str, protocol: str = "", host: str = "") -> None: + self.protocol = protocol + self.host = host + + def make_absolute_url(self, base_url: str | URL) -> URL: + if isinstance(base_url, str): + base_url = URL(base_url) + if self.protocol: + scheme = { + "http": {True: "https", False: "http"}, + "websocket": {True: "wss", False: "ws"}, + }[self.protocol][base_url.is_secure] + else: + scheme = base_url.scheme + + netloc = self.host or base_url.netloc + path = base_url.path.rstrip("/") + str(self) + return URL(scheme=scheme, netloc=netloc, path=path) + + +class Secret: + """ + Holds a string value that should not be revealed in tracebacks etc. + You should cast the value to `str` at the point it is required. + """ + + def __init__(self, value: str): + self._value = value + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + return f"{class_name}('**********')" + + def __str__(self) -> str: + return self._value + + def __bool__(self) -> bool: + return bool(self._value) + + +class CommaSeparatedStrings(typing.Sequence[str]): + def __init__(self, value: str | typing.Sequence[str]): + if isinstance(value, str): + splitter = shlex(value, posix=True) + splitter.whitespace = "," + splitter.whitespace_split = True + self._items = [item.strip() for item in splitter] + else: + self._items = list(value) + + def __len__(self) -> int: + return len(self._items) + + def __getitem__(self, index: int | slice) -> typing.Any: + return self._items[index] + + def __iter__(self) -> typing.Iterator[str]: + return iter(self._items) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + items = [item for item in self] + return f"{class_name}({items!r})" + + def __str__(self) -> str: + return ", ".join(repr(item) for item in self) + + +class ImmutableMultiDict(typing.Mapping[_KeyType, _CovariantValueType]): + _dict: dict[_KeyType, _CovariantValueType] + + def __init__( + self, + *args: ImmutableMultiDict[_KeyType, _CovariantValueType] + | typing.Mapping[_KeyType, _CovariantValueType] + | typing.Iterable[tuple[_KeyType, _CovariantValueType]], + **kwargs: typing.Any, + ) -> None: + assert len(args) < 2, "Too many arguments." + + value: typing.Any = args[0] if args else [] + if kwargs: + value = ImmutableMultiDict(value).multi_items() + ImmutableMultiDict(kwargs).multi_items() + + if not value: + _items: list[tuple[typing.Any, typing.Any]] = [] + elif hasattr(value, "multi_items"): + value = typing.cast(ImmutableMultiDict[_KeyType, _CovariantValueType], value) + _items = list(value.multi_items()) + elif hasattr(value, "items"): + value = typing.cast(typing.Mapping[_KeyType, _CovariantValueType], value) + _items = list(value.items()) + else: + value = typing.cast("list[tuple[typing.Any, typing.Any]]", value) + _items = list(value) + + self._dict = {k: v for k, v in _items} + self._list = _items + + def getlist(self, key: typing.Any) -> list[_CovariantValueType]: + return [item_value for item_key, item_value in self._list if item_key == key] + + def keys(self) -> typing.KeysView[_KeyType]: + return self._dict.keys() + + def values(self) -> typing.ValuesView[_CovariantValueType]: + return self._dict.values() + + def items(self) -> typing.ItemsView[_KeyType, _CovariantValueType]: + return self._dict.items() + + def multi_items(self) -> list[tuple[_KeyType, _CovariantValueType]]: + return list(self._list) + + def __getitem__(self, key: _KeyType) -> _CovariantValueType: + return self._dict[key] + + def __contains__(self, key: typing.Any) -> bool: + return key in self._dict + + def __iter__(self) -> typing.Iterator[_KeyType]: + return iter(self.keys()) + + def __len__(self) -> int: + return len(self._dict) + + def __eq__(self, other: typing.Any) -> bool: + if not isinstance(other, self.__class__): + return False + return sorted(self._list) == sorted(other._list) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + items = self.multi_items() + return f"{class_name}({items!r})" + + +class MultiDict(ImmutableMultiDict[typing.Any, typing.Any]): + def __setitem__(self, key: typing.Any, value: typing.Any) -> None: + self.setlist(key, [value]) + + def __delitem__(self, key: typing.Any) -> None: + self._list = [(k, v) for k, v in self._list if k != key] + del self._dict[key] + + def pop(self, key: typing.Any, default: typing.Any = None) -> typing.Any: + self._list = [(k, v) for k, v in self._list if k != key] + return self._dict.pop(key, default) + + def popitem(self) -> tuple[typing.Any, typing.Any]: + key, value = self._dict.popitem() + self._list = [(k, v) for k, v in self._list if k != key] + return key, value + + def poplist(self, key: typing.Any) -> list[typing.Any]: + values = [v for k, v in self._list if k == key] + self.pop(key) + return values + + def clear(self) -> None: + self._dict.clear() + self._list.clear() + + def setdefault(self, key: typing.Any, default: typing.Any = None) -> typing.Any: + if key not in self: + self._dict[key] = default + self._list.append((key, default)) + + return self[key] + + def setlist(self, key: typing.Any, values: list[typing.Any]) -> None: + if not values: + self.pop(key, None) + else: + existing_items = [(k, v) for (k, v) in self._list if k != key] + self._list = existing_items + [(key, value) for value in values] + self._dict[key] = values[-1] + + def append(self, key: typing.Any, value: typing.Any) -> None: + self._list.append((key, value)) + self._dict[key] = value + + def update( + self, + *args: MultiDict | typing.Mapping[typing.Any, typing.Any] | list[tuple[typing.Any, typing.Any]], + **kwargs: typing.Any, + ) -> None: + value = MultiDict(*args, **kwargs) + existing_items = [(k, v) for (k, v) in self._list if k not in value.keys()] + self._list = existing_items + value.multi_items() + self._dict.update(value) + + +class QueryParams(ImmutableMultiDict[str, str]): + """ + An immutable multidict. + """ + + def __init__( + self, + *args: ImmutableMultiDict[typing.Any, typing.Any] + | typing.Mapping[typing.Any, typing.Any] + | list[tuple[typing.Any, typing.Any]] + | str + | bytes, + **kwargs: typing.Any, + ) -> None: + assert len(args) < 2, "Too many arguments." + + value = args[0] if args else [] + + if isinstance(value, str): + super().__init__(parse_qsl(value, keep_blank_values=True), **kwargs) + elif isinstance(value, bytes): + super().__init__(parse_qsl(value.decode("latin-1"), keep_blank_values=True), **kwargs) + else: + super().__init__(*args, **kwargs) # type: ignore[arg-type] + self._list = [(str(k), str(v)) for k, v in self._list] + self._dict = {str(k): str(v) for k, v in self._dict.items()} + + def __str__(self) -> str: + return urlencode(self._list) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + query_string = str(self) + return f"{class_name}({query_string!r})" + + +class UploadFile: + """ + An uploaded file included as part of the request data. + """ + + def __init__( + self, + file: typing.BinaryIO, + *, + size: int | None = None, + filename: str | None = None, + headers: Headers | None = None, + ) -> None: + self.filename = filename + self.file = file + self.size = size + self.headers = headers or Headers() + + @property + def content_type(self) -> str | None: + return self.headers.get("content-type", None) + + @property + def _in_memory(self) -> bool: + # check for SpooledTemporaryFile._rolled + rolled_to_disk = getattr(self.file, "_rolled", True) + return not rolled_to_disk + + async def write(self, data: bytes) -> None: + if self.size is not None: + self.size += len(data) + + if self._in_memory: + self.file.write(data) + else: + await run_in_threadpool(self.file.write, data) + + async def read(self, size: int = -1) -> bytes: + if self._in_memory: + return self.file.read(size) + return await run_in_threadpool(self.file.read, size) + + async def seek(self, offset: int) -> None: + if self._in_memory: + self.file.seek(offset) + else: + await run_in_threadpool(self.file.seek, offset) + + async def close(self) -> None: + if self._in_memory: + self.file.close() + else: + await run_in_threadpool(self.file.close) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(filename={self.filename!r}, size={self.size!r}, headers={self.headers!r})" + + +class FormData(ImmutableMultiDict[str, typing.Union[UploadFile, str]]): + """ + An immutable multidict, containing both file uploads and text input. + """ + + def __init__( + self, + *args: FormData | typing.Mapping[str, str | UploadFile] | list[tuple[str, str | UploadFile]], + **kwargs: str | UploadFile, + ) -> None: + super().__init__(*args, **kwargs) + + async def close(self) -> None: + for key, value in self.multi_items(): + if isinstance(value, UploadFile): + await value.close() + + +class Headers(typing.Mapping[str, str]): + """ + An immutable, case-insensitive multidict. + """ + + def __init__( + self, + headers: typing.Mapping[str, str] | None = None, + raw: list[tuple[bytes, bytes]] | None = None, + scope: typing.MutableMapping[str, typing.Any] | None = None, + ) -> None: + self._list: list[tuple[bytes, bytes]] = [] + if headers is not None: + assert raw is None, 'Cannot set both "headers" and "raw".' + assert scope is None, 'Cannot set both "headers" and "scope".' + self._list = [(key.lower().encode("latin-1"), value.encode("latin-1")) for key, value in headers.items()] + elif raw is not None: + assert scope is None, 'Cannot set both "raw" and "scope".' + self._list = raw + elif scope is not None: + # scope["headers"] isn't necessarily a list + # it might be a tuple or other iterable + self._list = scope["headers"] = list(scope["headers"]) + + @property + def raw(self) -> list[tuple[bytes, bytes]]: + return list(self._list) + + def keys(self) -> list[str]: # type: ignore[override] + return [key.decode("latin-1") for key, value in self._list] + + def values(self) -> list[str]: # type: ignore[override] + return [value.decode("latin-1") for key, value in self._list] + + def items(self) -> list[tuple[str, str]]: # type: ignore[override] + return [(key.decode("latin-1"), value.decode("latin-1")) for key, value in self._list] + + def getlist(self, key: str) -> list[str]: + get_header_key = key.lower().encode("latin-1") + return [item_value.decode("latin-1") for item_key, item_value in self._list if item_key == get_header_key] + + def mutablecopy(self) -> MutableHeaders: + return MutableHeaders(raw=self._list[:]) + + def __getitem__(self, key: str) -> str: + get_header_key = key.lower().encode("latin-1") + for header_key, header_value in self._list: + if header_key == get_header_key: + return header_value.decode("latin-1") + raise KeyError(key) + + def __contains__(self, key: typing.Any) -> bool: + get_header_key = key.lower().encode("latin-1") + for header_key, header_value in self._list: + if header_key == get_header_key: + return True + return False + + def __iter__(self) -> typing.Iterator[typing.Any]: + return iter(self.keys()) + + def __len__(self) -> int: + return len(self._list) + + def __eq__(self, other: typing.Any) -> bool: + if not isinstance(other, Headers): + return False + return sorted(self._list) == sorted(other._list) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + as_dict = dict(self.items()) + if len(as_dict) == len(self): + return f"{class_name}({as_dict!r})" + return f"{class_name}(raw={self.raw!r})" + + +class MutableHeaders(Headers): + def __setitem__(self, key: str, value: str) -> None: + """ + Set the header `key` to `value`, removing any duplicate entries. + Retains insertion order. + """ + set_key = key.lower().encode("latin-1") + set_value = value.encode("latin-1") + + found_indexes: list[int] = [] + for idx, (item_key, item_value) in enumerate(self._list): + if item_key == set_key: + found_indexes.append(idx) + + for idx in reversed(found_indexes[1:]): + del self._list[idx] + + if found_indexes: + idx = found_indexes[0] + self._list[idx] = (set_key, set_value) + else: + self._list.append((set_key, set_value)) + + def __delitem__(self, key: str) -> None: + """ + Remove the header `key`. + """ + del_key = key.lower().encode("latin-1") + + pop_indexes: list[int] = [] + for idx, (item_key, item_value) in enumerate(self._list): + if item_key == del_key: + pop_indexes.append(idx) + + for idx in reversed(pop_indexes): + del self._list[idx] + + def __ior__(self, other: typing.Mapping[str, str]) -> MutableHeaders: + if not isinstance(other, typing.Mapping): + raise TypeError(f"Expected a mapping but got {other.__class__.__name__}") + self.update(other) + return self + + def __or__(self, other: typing.Mapping[str, str]) -> MutableHeaders: + if not isinstance(other, typing.Mapping): + raise TypeError(f"Expected a mapping but got {other.__class__.__name__}") + new = self.mutablecopy() + new.update(other) + return new + + @property + def raw(self) -> list[tuple[bytes, bytes]]: + return self._list + + def setdefault(self, key: str, value: str) -> str: + """ + If the header `key` does not exist, then set it to `value`. + Returns the header value. + """ + set_key = key.lower().encode("latin-1") + set_value = value.encode("latin-1") + + for idx, (item_key, item_value) in enumerate(self._list): + if item_key == set_key: + return item_value.decode("latin-1") + self._list.append((set_key, set_value)) + return value + + def update(self, other: typing.Mapping[str, str]) -> None: + for key, val in other.items(): + self[key] = val + + def append(self, key: str, value: str) -> None: + """ + Append a header, preserving any duplicate entries. + """ + append_key = key.lower().encode("latin-1") + append_value = value.encode("latin-1") + self._list.append((append_key, append_value)) + + def add_vary_header(self, vary: str) -> None: + existing = self.get("vary") + if existing is not None: + vary = ", ".join([existing, vary]) + self["vary"] = vary + + +class State: + """ + An object that can be used to store arbitrary state. + + Used for `request.state` and `app.state`. + """ + + _state: dict[str, typing.Any] + + def __init__(self, state: dict[str, typing.Any] | None = None): + if state is None: + state = {} + super().__setattr__("_state", state) + + def __setattr__(self, key: typing.Any, value: typing.Any) -> None: + self._state[key] = value + + def __getattr__(self, key: typing.Any) -> typing.Any: + try: + return self._state[key] + except KeyError: + message = "'{}' object has no attribute '{}'" + raise AttributeError(message.format(self.__class__.__name__, key)) + + def __delattr__(self, key: typing.Any) -> None: + del self._state[key] |