aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/aiohttp/payload.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/aiohttp/payload.py')
-rw-r--r--.venv/lib/python3.12/site-packages/aiohttp/payload.py519
1 files changed, 519 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/aiohttp/payload.py b/.venv/lib/python3.12/site-packages/aiohttp/payload.py
new file mode 100644
index 00000000..3f6d3672
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/aiohttp/payload.py
@@ -0,0 +1,519 @@
+import asyncio
+import enum
+import io
+import json
+import mimetypes
+import os
+import sys
+import warnings
+from abc import ABC, abstractmethod
+from itertools import chain
+from typing import (
+ IO,
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ Final,
+ Iterable,
+ Optional,
+ TextIO,
+ Tuple,
+ Type,
+ Union,
+)
+
+from multidict import CIMultiDict
+
+from . import hdrs
+from .abc import AbstractStreamWriter
+from .helpers import (
+ _SENTINEL,
+ content_disposition_header,
+ guess_filename,
+ parse_mimetype,
+ sentinel,
+)
+from .streams import StreamReader
+from .typedefs import JSONEncoder, _CIMultiDict
+
+__all__ = (
+ "PAYLOAD_REGISTRY",
+ "get_payload",
+ "payload_type",
+ "Payload",
+ "BytesPayload",
+ "StringPayload",
+ "IOBasePayload",
+ "BytesIOPayload",
+ "BufferedReaderPayload",
+ "TextIOPayload",
+ "StringIOPayload",
+ "JsonPayload",
+ "AsyncIterablePayload",
+)
+
+TOO_LARGE_BYTES_BODY: Final[int] = 2**20 # 1 MB
+
+if TYPE_CHECKING:
+ from typing import List
+
+
+class LookupError(Exception):
+ pass
+
+
+class Order(str, enum.Enum):
+ normal = "normal"
+ try_first = "try_first"
+ try_last = "try_last"
+
+
+def get_payload(data: Any, *args: Any, **kwargs: Any) -> "Payload":
+ return PAYLOAD_REGISTRY.get(data, *args, **kwargs)
+
+
+def register_payload(
+ factory: Type["Payload"], type: Any, *, order: Order = Order.normal
+) -> None:
+ PAYLOAD_REGISTRY.register(factory, type, order=order)
+
+
+class payload_type:
+ def __init__(self, type: Any, *, order: Order = Order.normal) -> None:
+ self.type = type
+ self.order = order
+
+ def __call__(self, factory: Type["Payload"]) -> Type["Payload"]:
+ register_payload(factory, self.type, order=self.order)
+ return factory
+
+
+PayloadType = Type["Payload"]
+_PayloadRegistryItem = Tuple[PayloadType, Any]
+
+
+class PayloadRegistry:
+ """Payload registry.
+
+ note: we need zope.interface for more efficient adapter search
+ """
+
+ __slots__ = ("_first", "_normal", "_last", "_normal_lookup")
+
+ def __init__(self) -> None:
+ self._first: List[_PayloadRegistryItem] = []
+ self._normal: List[_PayloadRegistryItem] = []
+ self._last: List[_PayloadRegistryItem] = []
+ self._normal_lookup: Dict[Any, PayloadType] = {}
+
+ def get(
+ self,
+ data: Any,
+ *args: Any,
+ _CHAIN: "Type[chain[_PayloadRegistryItem]]" = chain,
+ **kwargs: Any,
+ ) -> "Payload":
+ if self._first:
+ for factory, type_ in self._first:
+ if isinstance(data, type_):
+ return factory(data, *args, **kwargs)
+ # Try the fast lookup first
+ if lookup_factory := self._normal_lookup.get(type(data)):
+ return lookup_factory(data, *args, **kwargs)
+ # Bail early if its already a Payload
+ if isinstance(data, Payload):
+ return data
+ # Fallback to the slower linear search
+ for factory, type_ in _CHAIN(self._normal, self._last):
+ if isinstance(data, type_):
+ return factory(data, *args, **kwargs)
+ raise LookupError()
+
+ def register(
+ self, factory: PayloadType, type: Any, *, order: Order = Order.normal
+ ) -> None:
+ if order is Order.try_first:
+ self._first.append((factory, type))
+ elif order is Order.normal:
+ self._normal.append((factory, type))
+ if isinstance(type, Iterable):
+ for t in type:
+ self._normal_lookup[t] = factory
+ else:
+ self._normal_lookup[type] = factory
+ elif order is Order.try_last:
+ self._last.append((factory, type))
+ else:
+ raise ValueError(f"Unsupported order {order!r}")
+
+
+class Payload(ABC):
+
+ _default_content_type: str = "application/octet-stream"
+ _size: Optional[int] = None
+
+ def __init__(
+ self,
+ value: Any,
+ headers: Optional[
+ Union[_CIMultiDict, Dict[str, str], Iterable[Tuple[str, str]]]
+ ] = None,
+ content_type: Union[str, None, _SENTINEL] = sentinel,
+ filename: Optional[str] = None,
+ encoding: Optional[str] = None,
+ **kwargs: Any,
+ ) -> None:
+ self._encoding = encoding
+ self._filename = filename
+ self._headers: _CIMultiDict = CIMultiDict()
+ self._value = value
+ if content_type is not sentinel and content_type is not None:
+ self._headers[hdrs.CONTENT_TYPE] = content_type
+ elif self._filename is not None:
+ if sys.version_info >= (3, 13):
+ guesser = mimetypes.guess_file_type
+ else:
+ guesser = mimetypes.guess_type
+ content_type = guesser(self._filename)[0]
+ if content_type is None:
+ content_type = self._default_content_type
+ self._headers[hdrs.CONTENT_TYPE] = content_type
+ else:
+ self._headers[hdrs.CONTENT_TYPE] = self._default_content_type
+ if headers:
+ self._headers.update(headers)
+
+ @property
+ def size(self) -> Optional[int]:
+ """Size of the payload."""
+ return self._size
+
+ @property
+ def filename(self) -> Optional[str]:
+ """Filename of the payload."""
+ return self._filename
+
+ @property
+ def headers(self) -> _CIMultiDict:
+ """Custom item headers"""
+ return self._headers
+
+ @property
+ def _binary_headers(self) -> bytes:
+ return (
+ "".join([k + ": " + v + "\r\n" for k, v in self.headers.items()]).encode(
+ "utf-8"
+ )
+ + b"\r\n"
+ )
+
+ @property
+ def encoding(self) -> Optional[str]:
+ """Payload encoding"""
+ return self._encoding
+
+ @property
+ def content_type(self) -> str:
+ """Content type"""
+ return self._headers[hdrs.CONTENT_TYPE]
+
+ def set_content_disposition(
+ self,
+ disptype: str,
+ quote_fields: bool = True,
+ _charset: str = "utf-8",
+ **params: Any,
+ ) -> None:
+ """Sets ``Content-Disposition`` header."""
+ self._headers[hdrs.CONTENT_DISPOSITION] = content_disposition_header(
+ disptype, quote_fields=quote_fields, _charset=_charset, **params
+ )
+
+ @abstractmethod
+ def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
+ """Return string representation of the value.
+
+ This is named decode() to allow compatibility with bytes objects.
+ """
+
+ @abstractmethod
+ async def write(self, writer: AbstractStreamWriter) -> None:
+ """Write payload.
+
+ writer is an AbstractStreamWriter instance:
+ """
+
+
+class BytesPayload(Payload):
+ _value: bytes
+
+ def __init__(
+ self, value: Union[bytes, bytearray, memoryview], *args: Any, **kwargs: Any
+ ) -> None:
+ if "content_type" not in kwargs:
+ kwargs["content_type"] = "application/octet-stream"
+
+ super().__init__(value, *args, **kwargs)
+
+ if isinstance(value, memoryview):
+ self._size = value.nbytes
+ elif isinstance(value, (bytes, bytearray)):
+ self._size = len(value)
+ else:
+ raise TypeError(f"value argument must be byte-ish, not {type(value)!r}")
+
+ if self._size > TOO_LARGE_BYTES_BODY:
+ kwargs = {"source": self}
+ warnings.warn(
+ "Sending a large body directly with raw bytes might"
+ " lock the event loop. You should probably pass an "
+ "io.BytesIO object instead",
+ ResourceWarning,
+ **kwargs,
+ )
+
+ def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
+ return self._value.decode(encoding, errors)
+
+ async def write(self, writer: AbstractStreamWriter) -> None:
+ await writer.write(self._value)
+
+
+class StringPayload(BytesPayload):
+ def __init__(
+ self,
+ value: str,
+ *args: Any,
+ encoding: Optional[str] = None,
+ content_type: Optional[str] = None,
+ **kwargs: Any,
+ ) -> None:
+
+ if encoding is None:
+ if content_type is None:
+ real_encoding = "utf-8"
+ content_type = "text/plain; charset=utf-8"
+ else:
+ mimetype = parse_mimetype(content_type)
+ real_encoding = mimetype.parameters.get("charset", "utf-8")
+ else:
+ if content_type is None:
+ content_type = "text/plain; charset=%s" % encoding
+ real_encoding = encoding
+
+ super().__init__(
+ value.encode(real_encoding),
+ encoding=real_encoding,
+ content_type=content_type,
+ *args,
+ **kwargs,
+ )
+
+
+class StringIOPayload(StringPayload):
+ def __init__(self, value: IO[str], *args: Any, **kwargs: Any) -> None:
+ super().__init__(value.read(), *args, **kwargs)
+
+
+class IOBasePayload(Payload):
+ _value: io.IOBase
+
+ def __init__(
+ self, value: IO[Any], disposition: str = "attachment", *args: Any, **kwargs: Any
+ ) -> None:
+ if "filename" not in kwargs:
+ kwargs["filename"] = guess_filename(value)
+
+ super().__init__(value, *args, **kwargs)
+
+ if self._filename is not None and disposition is not None:
+ if hdrs.CONTENT_DISPOSITION not in self.headers:
+ self.set_content_disposition(disposition, filename=self._filename)
+
+ async def write(self, writer: AbstractStreamWriter) -> None:
+ loop = asyncio.get_event_loop()
+ try:
+ chunk = await loop.run_in_executor(None, self._value.read, 2**16)
+ while chunk:
+ await writer.write(chunk)
+ chunk = await loop.run_in_executor(None, self._value.read, 2**16)
+ finally:
+ await loop.run_in_executor(None, self._value.close)
+
+ def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
+ return "".join(r.decode(encoding, errors) for r in self._value.readlines())
+
+
+class TextIOPayload(IOBasePayload):
+ _value: io.TextIOBase
+
+ def __init__(
+ self,
+ value: TextIO,
+ *args: Any,
+ encoding: Optional[str] = None,
+ content_type: Optional[str] = None,
+ **kwargs: Any,
+ ) -> None:
+
+ if encoding is None:
+ if content_type is None:
+ encoding = "utf-8"
+ content_type = "text/plain; charset=utf-8"
+ else:
+ mimetype = parse_mimetype(content_type)
+ encoding = mimetype.parameters.get("charset", "utf-8")
+ else:
+ if content_type is None:
+ content_type = "text/plain; charset=%s" % encoding
+
+ super().__init__(
+ value,
+ content_type=content_type,
+ encoding=encoding,
+ *args,
+ **kwargs,
+ )
+
+ @property
+ def size(self) -> Optional[int]:
+ try:
+ return os.fstat(self._value.fileno()).st_size - self._value.tell()
+ except OSError:
+ return None
+
+ def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
+ return self._value.read()
+
+ async def write(self, writer: AbstractStreamWriter) -> None:
+ loop = asyncio.get_event_loop()
+ try:
+ chunk = await loop.run_in_executor(None, self._value.read, 2**16)
+ while chunk:
+ data = (
+ chunk.encode(encoding=self._encoding)
+ if self._encoding
+ else chunk.encode()
+ )
+ await writer.write(data)
+ chunk = await loop.run_in_executor(None, self._value.read, 2**16)
+ finally:
+ await loop.run_in_executor(None, self._value.close)
+
+
+class BytesIOPayload(IOBasePayload):
+ _value: io.BytesIO
+
+ @property
+ def size(self) -> int:
+ position = self._value.tell()
+ end = self._value.seek(0, os.SEEK_END)
+ self._value.seek(position)
+ return end - position
+
+ def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
+ return self._value.read().decode(encoding, errors)
+
+
+class BufferedReaderPayload(IOBasePayload):
+ _value: io.BufferedIOBase
+
+ @property
+ def size(self) -> Optional[int]:
+ try:
+ return os.fstat(self._value.fileno()).st_size - self._value.tell()
+ except (OSError, AttributeError):
+ # data.fileno() is not supported, e.g.
+ # io.BufferedReader(io.BytesIO(b'data'))
+ # For some file-like objects (e.g. tarfile), the fileno() attribute may
+ # not exist at all, and will instead raise an AttributeError.
+ return None
+
+ def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
+ return self._value.read().decode(encoding, errors)
+
+
+class JsonPayload(BytesPayload):
+ def __init__(
+ self,
+ value: Any,
+ encoding: str = "utf-8",
+ content_type: str = "application/json",
+ dumps: JSONEncoder = json.dumps,
+ *args: Any,
+ **kwargs: Any,
+ ) -> None:
+
+ super().__init__(
+ dumps(value).encode(encoding),
+ content_type=content_type,
+ encoding=encoding,
+ *args,
+ **kwargs,
+ )
+
+
+if TYPE_CHECKING:
+ from typing import AsyncIterable, AsyncIterator
+
+ _AsyncIterator = AsyncIterator[bytes]
+ _AsyncIterable = AsyncIterable[bytes]
+else:
+ from collections.abc import AsyncIterable, AsyncIterator
+
+ _AsyncIterator = AsyncIterator
+ _AsyncIterable = AsyncIterable
+
+
+class AsyncIterablePayload(Payload):
+
+ _iter: Optional[_AsyncIterator] = None
+ _value: _AsyncIterable
+
+ def __init__(self, value: _AsyncIterable, *args: Any, **kwargs: Any) -> None:
+ if not isinstance(value, AsyncIterable):
+ raise TypeError(
+ "value argument must support "
+ "collections.abc.AsyncIterable interface, "
+ "got {!r}".format(type(value))
+ )
+
+ if "content_type" not in kwargs:
+ kwargs["content_type"] = "application/octet-stream"
+
+ super().__init__(value, *args, **kwargs)
+
+ self._iter = value.__aiter__()
+
+ async def write(self, writer: AbstractStreamWriter) -> None:
+ if self._iter:
+ try:
+ # iter is not None check prevents rare cases
+ # when the case iterable is used twice
+ while True:
+ chunk = await self._iter.__anext__()
+ await writer.write(chunk)
+ except StopAsyncIteration:
+ self._iter = None
+
+ def decode(self, encoding: str = "utf-8", errors: str = "strict") -> str:
+ raise TypeError("Unable to decode.")
+
+
+class StreamReaderPayload(AsyncIterablePayload):
+ def __init__(self, value: StreamReader, *args: Any, **kwargs: Any) -> None:
+ super().__init__(value.iter_any(), *args, **kwargs)
+
+
+PAYLOAD_REGISTRY = PayloadRegistry()
+PAYLOAD_REGISTRY.register(BytesPayload, (bytes, bytearray, memoryview))
+PAYLOAD_REGISTRY.register(StringPayload, str)
+PAYLOAD_REGISTRY.register(StringIOPayload, io.StringIO)
+PAYLOAD_REGISTRY.register(TextIOPayload, io.TextIOBase)
+PAYLOAD_REGISTRY.register(BytesIOPayload, io.BytesIO)
+PAYLOAD_REGISTRY.register(BufferedReaderPayload, (io.BufferedReader, io.BufferedRandom))
+PAYLOAD_REGISTRY.register(IOBasePayload, io.IOBase)
+PAYLOAD_REGISTRY.register(StreamReaderPayload, StreamReader)
+# try_last for giving a chance to more specialized async interables like
+# multidict.BodyPartReaderPayload override the default
+PAYLOAD_REGISTRY.register(AsyncIterablePayload, AsyncIterable, order=Order.try_last)