about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/aiohttp/payload.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/aiohttp/payload.py')
-rw-r--r--.venv/lib/python3.12/site-packages/aiohttp/payload.py519
1 files changed, 519 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/aiohttp/payload.py b/.venv/lib/python3.12/site-packages/aiohttp/payload.py
new file mode 100644
index 00000000..3f6d3672
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/aiohttp/payload.py
@@ -0,0 +1,519 @@
+import asyncio
+import enum
+import io
+import json
+import mimetypes
+import os
+import sys
+import warnings
+from abc import ABC, abstractmethod
+from itertools import chain
+from typing import (
+    IO,
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Final,
+    Iterable,
+    Optional,
+    TextIO,
+    Tuple,
+    Type,
+    Union,
+)
+
+from multidict import CIMultiDict
+
+from . import hdrs
+from .abc import AbstractStreamWriter
+from .helpers import (
+    _SENTINEL,
+    content_disposition_header,
+    guess_filename,
+    parse_mimetype,
+    sentinel,
+)
+from .streams import StreamReader
+from .typedefs import JSONEncoder, _CIMultiDict
+
+__all__ = (
+    "PAYLOAD_REGISTRY",
+    "get_payload",
+    "payload_type",
+    "Payload",
+    "BytesPayload",
+    "StringPayload",
+    "IOBasePayload",
+    "BytesIOPayload",
+    "BufferedReaderPayload",
+    "TextIOPayload",
+    "StringIOPayload",
+    "JsonPayload",
+    "AsyncIterablePayload",
+)
+
+TOO_LARGE_BYTES_BODY: Final[int] = 2**20  # 1 MB
+
+if TYPE_CHECKING:
+    from typing import List
+
+
+class LookupError(Exception):
+    pass
+
+
+class Order(str, enum.Enum):
+    normal = "normal"
+    try_first = "try_first"
+    try_last = "try_last"
+
+
+def get_payload(data: Any, *args: Any, **kwargs: Any) -> "Payload":
+    return PAYLOAD_REGISTRY.get(data, *args, **kwargs)
+
+
+def register_payload(
+    factory: Type["Payload"], type: Any, *, order: Order = Order.normal
+) -> None:
+    PAYLOAD_REGISTRY.register(factory, type, order=order)
+
+
+class payload_type:
+    def __init__(self, type: Any, *, order: Order = Order.normal) -> None:
+        self.type = type
+        self.order = order
+
+    def __call__(self, factory: Type["Payload"]) -> Type["Payload"]:
+        register_payload(factory, self.type, order=self.order)
+        return factory
+
+
+PayloadType = Type["Payload"]
+_PayloadRegistryItem = Tuple[PayloadType, Any]
+
+
+class PayloadRegistry:
+    """Payload registry.
+
+    note: we need zope.interface for more efficient adapter search
+    """
+
+    __slots__ = ("_first", "_normal", "_last", "_normal_lookup")
+
+    def __init__(self) -> None:
+        self._first: List[_PayloadRegistryItem] = []
+        self._normal: List[_PayloadRegistryItem] = []
+        self._last: List[_PayloadRegistryItem] = []
+        self._normal_lookup: Dict[Any, PayloadType] = {}
+
+    def get(
+        self,
+        data: Any,
+        *args: Any,
+        _CHAIN: "Type[chain[_PayloadRegistryItem]]" = chain,
+        **kwargs: Any,
+    ) -> "Payload":
+        if self._first:
+            for factory, type_ in self._first:
+                if isinstance(data, type_):
+                    return factory(data, *args, **kwargs)
+        # Try the fast lookup first
+        if lookup_factory := self._normal_lookup.get(type(data)):
+            return lookup_factory(data, *args, **kwargs)
+        # Bail early if its already a Payload
+        if isinstance(data, Payload):
+            return data
+        # Fallback to the slower linear search
+        for factory, type_ in _CHAIN(self._normal, self._last):
+            if isinstance(data, type_):
+                return factory(data, *args, **kwargs)
+        raise LookupError()
+
+    def register(
+        self, factory: PayloadType, type: Any, *, order: Order = Order.normal
+    ) -> None:
+        if order is Order.try_first:
+            self._first.append((factory, type))
+        elif order is Order.normal:
+            self._normal.append((factory, type))
+            if isinstance(type, Iterable):
+                for t in type:
+                    self._normal_lookup[t] = factory
+            else:
+                self._normal_lookup[type] = factory
+        elif order is Order.try_last:
+            self._last.append((factory, type))
+        else:
+            raise ValueError(f"Unsupported order {order!r}")
+
+
+class Payload(ABC):
+
+    _default_content_type: str = "application/octet-stream"
+    _size: Optional[int] = None
+
+    def __init__(
+        self,
+        value: Any,
+        headers: Optional[
+            Union[_CIMultiDict, Dict[str, str], Iterable[Tuple[str, str]]]
+        ] = None,
+        content_type: Union[str, None, _SENTINEL] = sentinel,
+        filename: Optional[str] = None,
+        encoding: Optional[str] = None,
+        **kwargs: Any,
+    ) -> None:
+        self._encoding = encoding
+        self._filename = filename
+        self._headers: _CIMultiDict = CIMultiDict()
+        self._value = value
+        if content_type is not sentinel and content_type is not None:
+            self._headers[hdrs.CONTENT_TYPE] = content_type
+        elif self._filename is not None:
+            if sys.version_info >= (3, 13):
+                guesser = mimetypes.guess_file_type
+            else:
+                guesser = mimetypes.guess_type
+            content_type = guesser(self._filename)[0]
+            if content_type is None:
+                content_type = self._default_content_type
+            self._headers[hdrs.CONTENT_TYPE] = content_type
+        else:
+            self._headers[hdrs.CONTENT_TYPE] = self._default_content_type
+        if headers:
+            self._headers.update(headers)
+
+    @property
+    def size(self) -> Optional[int]:
+        """Size of the payload."""
+        return self._size
+
+    @property
+    def filename(self) -> Optional[str]:
+        """Filename of the payload."""
+        return self._filename
+
+    @property
+    def headers(self) -> _CIMultiDict:
+        """Custom item headers"""
+        return self._headers
+
+    @property
+    def _binary_headers(self) -> bytes:
+        return (
+            "".join([k + ": " + v + "\r\n" for k, v in self.headers.items()]).encode(
+                "utf-8"
+            )
+            + b"\r\n"
+        )
+
+    @property
+    def encoding(self) -> Optional[str]:
+        """Payload encoding"""
+        return self._encoding
+
+    @property
+    def content_type(self) -> str:
+        """Content type"""
+        return self._headers[hdrs.CONTENT_TYPE]
+
+    def set_content_disposition(
+        self,
+        disptype: str,
+        quote_fields: bool = True,
+        _charset: str = "utf-8",
+        **params: Any,
+    ) -> None:
+        """Sets ``Content-Disposition`` header."""
+        self._headers[hdrs.CONTENT_DISPOSITION] = content_disposition_header(
+            disptype, quote_fields=quote_fields, _charset=_charset, **params
+        )
+
+    @abstractmethod
+    def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
+        """Return string representation of the value.
+
+        This is named decode() to allow compatibility with bytes objects.
+        """
+
+    @abstractmethod
+    async def write(self, writer: AbstractStreamWriter) -> None:
+        """Write payload.
+
+        writer is an AbstractStreamWriter instance:
+        """
+
+
+class BytesPayload(Payload):
+    _value: bytes
+
+    def __init__(
+        self, value: Union[bytes, bytearray, memoryview], *args: Any, **kwargs: Any
+    ) -> None:
+        if "content_type" not in kwargs:
+            kwargs["content_type"] = "application/octet-stream"
+
+        super().__init__(value, *args, **kwargs)
+
+        if isinstance(value, memoryview):
+            self._size = value.nbytes
+        elif isinstance(value, (bytes, bytearray)):
+            self._size = len(value)
+        else:
+            raise TypeError(f"value argument must be byte-ish, not {type(value)!r}")
+
+        if self._size > TOO_LARGE_BYTES_BODY:
+            kwargs = {"source": self}
+            warnings.warn(
+                "Sending a large body directly with raw bytes might"
+                " lock the event loop. You should probably pass an "
+                "io.BytesIO object instead",
+                ResourceWarning,
+                **kwargs,
+            )
+
+    def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
+        return self._value.decode(encoding, errors)
+
+    async def write(self, writer: AbstractStreamWriter) -> None:
+        await writer.write(self._value)
+
+
+class StringPayload(BytesPayload):
+    def __init__(
+        self,
+        value: str,
+        *args: Any,
+        encoding: Optional[str] = None,
+        content_type: Optional[str] = None,
+        **kwargs: Any,
+    ) -> None:
+
+        if encoding is None:
+            if content_type is None:
+                real_encoding = "utf-8"
+                content_type = "text/plain; charset=utf-8"
+            else:
+                mimetype = parse_mimetype(content_type)
+                real_encoding = mimetype.parameters.get("charset", "utf-8")
+        else:
+            if content_type is None:
+                content_type = "text/plain; charset=%s" % encoding
+            real_encoding = encoding
+
+        super().__init__(
+            value.encode(real_encoding),
+            encoding=real_encoding,
+            content_type=content_type,
+            *args,
+            **kwargs,
+        )
+
+
+class StringIOPayload(StringPayload):
+    def __init__(self, value: IO[str], *args: Any, **kwargs: Any) -> None:
+        super().__init__(value.read(), *args, **kwargs)
+
+
+class IOBasePayload(Payload):
+    _value: io.IOBase
+
+    def __init__(
+        self, value: IO[Any], disposition: str = "attachment", *args: Any, **kwargs: Any
+    ) -> None:
+        if "filename" not in kwargs:
+            kwargs["filename"] = guess_filename(value)
+
+        super().__init__(value, *args, **kwargs)
+
+        if self._filename is not None and disposition is not None:
+            if hdrs.CONTENT_DISPOSITION not in self.headers:
+                self.set_content_disposition(disposition, filename=self._filename)
+
+    async def write(self, writer: AbstractStreamWriter) -> None:
+        loop = asyncio.get_event_loop()
+        try:
+            chunk = await loop.run_in_executor(None, self._value.read, 2**16)
+            while chunk:
+                await writer.write(chunk)
+                chunk = await loop.run_in_executor(None, self._value.read, 2**16)
+        finally:
+            await loop.run_in_executor(None, self._value.close)
+
+    def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
+        return "".join(r.decode(encoding, errors) for r in self._value.readlines())
+
+
+class TextIOPayload(IOBasePayload):
+    _value: io.TextIOBase
+
+    def __init__(
+        self,
+        value: TextIO,
+        *args: Any,
+        encoding: Optional[str] = None,
+        content_type: Optional[str] = None,
+        **kwargs: Any,
+    ) -> None:
+
+        if encoding is None:
+            if content_type is None:
+                encoding = "utf-8"
+                content_type = "text/plain; charset=utf-8"
+            else:
+                mimetype = parse_mimetype(content_type)
+                encoding = mimetype.parameters.get("charset", "utf-8")
+        else:
+            if content_type is None:
+                content_type = "text/plain; charset=%s" % encoding
+
+        super().__init__(
+            value,
+            content_type=content_type,
+            encoding=encoding,
+            *args,
+            **kwargs,
+        )
+
+    @property
+    def size(self) -> Optional[int]:
+        try:
+            return os.fstat(self._value.fileno()).st_size - self._value.tell()
+        except OSError:
+            return None
+
+    def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
+        return self._value.read()
+
+    async def write(self, writer: AbstractStreamWriter) -> None:
+        loop = asyncio.get_event_loop()
+        try:
+            chunk = await loop.run_in_executor(None, self._value.read, 2**16)
+            while chunk:
+                data = (
+                    chunk.encode(encoding=self._encoding)
+                    if self._encoding
+                    else chunk.encode()
+                )
+                await writer.write(data)
+                chunk = await loop.run_in_executor(None, self._value.read, 2**16)
+        finally:
+            await loop.run_in_executor(None, self._value.close)
+
+
+class BytesIOPayload(IOBasePayload):
+    _value: io.BytesIO
+
+    @property
+    def size(self) -> int:
+        position = self._value.tell()
+        end = self._value.seek(0, os.SEEK_END)
+        self._value.seek(position)
+        return end - position
+
+    def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
+        return self._value.read().decode(encoding, errors)
+
+
+class BufferedReaderPayload(IOBasePayload):
+    _value: io.BufferedIOBase
+
+    @property
+    def size(self) -> Optional[int]:
+        try:
+            return os.fstat(self._value.fileno()).st_size - self._value.tell()
+        except (OSError, AttributeError):
+            # data.fileno() is not supported, e.g.
+            # io.BufferedReader(io.BytesIO(b'data'))
+            # For some file-like objects (e.g. tarfile), the fileno() attribute may
+            # not exist at all, and will instead raise an AttributeError.
+            return None
+
+    def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
+        return self._value.read().decode(encoding, errors)
+
+
+class JsonPayload(BytesPayload):
+    def __init__(
+        self,
+        value: Any,
+        encoding: str = "utf-8",
+        content_type: str = "application/json",
+        dumps: JSONEncoder = json.dumps,
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
+
+        super().__init__(
+            dumps(value).encode(encoding),
+            content_type=content_type,
+            encoding=encoding,
+            *args,
+            **kwargs,
+        )
+
+
+if TYPE_CHECKING:
+    from typing import AsyncIterable, AsyncIterator
+
+    _AsyncIterator = AsyncIterator[bytes]
+    _AsyncIterable = AsyncIterable[bytes]
+else:
+    from collections.abc import AsyncIterable, AsyncIterator
+
+    _AsyncIterator = AsyncIterator
+    _AsyncIterable = AsyncIterable
+
+
+class AsyncIterablePayload(Payload):
+
+    _iter: Optional[_AsyncIterator] = None
+    _value: _AsyncIterable
+
+    def __init__(self, value: _AsyncIterable, *args: Any, **kwargs: Any) -> None:
+        if not isinstance(value, AsyncIterable):
+            raise TypeError(
+                "value argument must support "
+                "collections.abc.AsyncIterable interface, "
+                "got {!r}".format(type(value))
+            )
+
+        if "content_type" not in kwargs:
+            kwargs["content_type"] = "application/octet-stream"
+
+        super().__init__(value, *args, **kwargs)
+
+        self._iter = value.__aiter__()
+
+    async def write(self, writer: AbstractStreamWriter) -> None:
+        if self._iter:
+            try:
+                # iter is not None check prevents rare cases
+                # when the case iterable is used twice
+                while True:
+                    chunk = await self._iter.__anext__()
+                    await writer.write(chunk)
+            except StopAsyncIteration:
+                self._iter = None
+
+    def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
+        raise TypeError("Unable to decode.")
+
+
+class StreamReaderPayload(AsyncIterablePayload):
+    def __init__(self, value: StreamReader, *args: Any, **kwargs: Any) -> None:
+        super().__init__(value.iter_any(), *args, **kwargs)
+
+
+PAYLOAD_REGISTRY = PayloadRegistry()
+PAYLOAD_REGISTRY.register(BytesPayload, (bytes, bytearray, memoryview))
+PAYLOAD_REGISTRY.register(StringPayload, str)
+PAYLOAD_REGISTRY.register(StringIOPayload, io.StringIO)
+PAYLOAD_REGISTRY.register(TextIOPayload, io.TextIOBase)
+PAYLOAD_REGISTRY.register(BytesIOPayload, io.BytesIO)
+PAYLOAD_REGISTRY.register(BufferedReaderPayload, (io.BufferedReader, io.BufferedRandom))
+PAYLOAD_REGISTRY.register(IOBasePayload, io.IOBase)
+PAYLOAD_REGISTRY.register(StreamReaderPayload, StreamReader)
+# try_last for giving a chance to more specialized async interables like
+# multidict.BodyPartReaderPayload override the default
+PAYLOAD_REGISTRY.register(AsyncIterablePayload, AsyncIterable, order=Order.try_last)