diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/aiohttp/_websocket')
15 files changed, 1534 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/aiohttp/_websocket/.hash/mask.pxd.hash b/.venv/lib/python3.12/site-packages/aiohttp/_websocket/.hash/mask.pxd.hash new file mode 100644 index 00000000..eadfed3d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/aiohttp/_websocket/.hash/mask.pxd.hash @@ -0,0 +1 @@ +b01999d409b29bd916e067bc963d5f2d9ee63cfc9ae0bccb769910131417bf93 /home/runner/work/aiohttp/aiohttp/aiohttp/_websocket/mask.pxd diff --git a/.venv/lib/python3.12/site-packages/aiohttp/_websocket/.hash/mask.pyx.hash b/.venv/lib/python3.12/site-packages/aiohttp/_websocket/.hash/mask.pyx.hash new file mode 100644 index 00000000..5cd7ae67 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/aiohttp/_websocket/.hash/mask.pyx.hash @@ -0,0 +1 @@ +0478ceb55d0ed30ef1a7da742cd003449bc69a07cf9fdb06789bd2b347cbfffe /home/runner/work/aiohttp/aiohttp/aiohttp/_websocket/mask.pyx diff --git a/.venv/lib/python3.12/site-packages/aiohttp/_websocket/.hash/reader_c.pxd.hash b/.venv/lib/python3.12/site-packages/aiohttp/_websocket/.hash/reader_c.pxd.hash new file mode 100644 index 00000000..ff743553 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/aiohttp/_websocket/.hash/reader_c.pxd.hash @@ -0,0 +1 @@ +f6b3160a9002d639e0eff82da8b8d196a42ff6aed490e9faded2107eada4f067 /home/runner/work/aiohttp/aiohttp/aiohttp/_websocket/reader_c.pxd diff --git a/.venv/lib/python3.12/site-packages/aiohttp/_websocket/__init__.py b/.venv/lib/python3.12/site-packages/aiohttp/_websocket/__init__.py new file mode 100644 index 00000000..836257cc --- /dev/null +++ b/.venv/lib/python3.12/site-packages/aiohttp/_websocket/__init__.py @@ -0,0 +1 @@ +"""WebSocket protocol versions 13 and 8.""" diff --git a/.venv/lib/python3.12/site-packages/aiohttp/_websocket/helpers.py b/.venv/lib/python3.12/site-packages/aiohttp/_websocket/helpers.py new file mode 100644 index 00000000..0bb58df9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/aiohttp/_websocket/helpers.py @@ -0,0 +1,147 @@ +"""Helpers for WebSocket protocol versions 13 and 8.""" + +import functools +import re +from struct import Struct +from typing import TYPE_CHECKING, Final, List, Optional, Pattern, Tuple + +from ..helpers import NO_EXTENSIONS +from .models import WSHandshakeError + +UNPACK_LEN3 = Struct("!Q").unpack_from +UNPACK_CLOSE_CODE = Struct("!H").unpack +PACK_LEN1 = Struct("!BB").pack +PACK_LEN2 = Struct("!BBH").pack +PACK_LEN3 = Struct("!BBQ").pack +PACK_CLOSE_CODE = Struct("!H").pack +PACK_RANDBITS = Struct("!L").pack +MSG_SIZE: Final[int] = 2**14 +MASK_LEN: Final[int] = 4 + +WS_KEY: Final[bytes] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + + +# Used by _websocket_mask_python +@functools.lru_cache +def _xor_table() -> List[bytes]: + return [bytes(a ^ b for a in range(256)) for b in range(256)] + + +def _websocket_mask_python(mask: bytes, data: bytearray) -> None: + """Websocket masking function. + + `mask` is a `bytes` object of length 4; `data` is a `bytearray` + object of any length. The contents of `data` are masked with `mask`, + as specified in section 5.3 of RFC 6455. + + Note that this function mutates the `data` argument. + + This pure-python implementation may be replaced by an optimized + version when available. + + """ + assert isinstance(data, bytearray), data + assert len(mask) == 4, mask + + if data: + _XOR_TABLE = _xor_table() + a, b, c, d = (_XOR_TABLE[n] for n in mask) + data[::4] = data[::4].translate(a) + data[1::4] = data[1::4].translate(b) + data[2::4] = data[2::4].translate(c) + data[3::4] = data[3::4].translate(d) + + +if TYPE_CHECKING or NO_EXTENSIONS: # pragma: no cover + websocket_mask = _websocket_mask_python +else: + try: + from .mask import _websocket_mask_cython # type: ignore[import-not-found] + + websocket_mask = _websocket_mask_cython + except ImportError: # pragma: no cover + websocket_mask = _websocket_mask_python + + +_WS_EXT_RE: Final[Pattern[str]] = re.compile( + r"^(?:;\s*(?:" + r"(server_no_context_takeover)|" + r"(client_no_context_takeover)|" + r"(server_max_window_bits(?:=(\d+))?)|" + r"(client_max_window_bits(?:=(\d+))?)))*$" +) + +_WS_EXT_RE_SPLIT: Final[Pattern[str]] = re.compile(r"permessage-deflate([^,]+)?") + + +def ws_ext_parse(extstr: Optional[str], isserver: bool = False) -> Tuple[int, bool]: + if not extstr: + return 0, False + + compress = 0 + notakeover = False + for ext in _WS_EXT_RE_SPLIT.finditer(extstr): + defext = ext.group(1) + # Return compress = 15 when get `permessage-deflate` + if not defext: + compress = 15 + break + match = _WS_EXT_RE.match(defext) + if match: + compress = 15 + if isserver: + # Server never fail to detect compress handshake. + # Server does not need to send max wbit to client + if match.group(4): + compress = int(match.group(4)) + # Group3 must match if group4 matches + # Compress wbit 8 does not support in zlib + # If compress level not support, + # CONTINUE to next extension + if compress > 15 or compress < 9: + compress = 0 + continue + if match.group(1): + notakeover = True + # Ignore regex group 5 & 6 for client_max_window_bits + break + else: + if match.group(6): + compress = int(match.group(6)) + # Group5 must match if group6 matches + # Compress wbit 8 does not support in zlib + # If compress level not support, + # FAIL the parse progress + if compress > 15 or compress < 9: + raise WSHandshakeError("Invalid window size") + if match.group(2): + notakeover = True + # Ignore regex group 5 & 6 for client_max_window_bits + break + # Return Fail if client side and not match + elif not isserver: + raise WSHandshakeError("Extension for deflate not supported" + ext.group(1)) + + return compress, notakeover + + +def ws_ext_gen( + compress: int = 15, isserver: bool = False, server_notakeover: bool = False +) -> str: + # client_notakeover=False not used for server + # compress wbit 8 does not support in zlib + if compress < 9 or compress > 15: + raise ValueError( + "Compress wbits must between 9 and 15, zlib does not support wbits=8" + ) + enabledext = ["permessage-deflate"] + if not isserver: + enabledext.append("client_max_window_bits") + + if compress < 15: + enabledext.append("server_max_window_bits=" + str(compress)) + if server_notakeover: + enabledext.append("server_no_context_takeover") + # if client_notakeover: + # enabledext.append('client_no_context_takeover') + return "; ".join(enabledext) diff --git a/.venv/lib/python3.12/site-packages/aiohttp/_websocket/mask.cpython-312-x86_64-linux-gnu.so b/.venv/lib/python3.12/site-packages/aiohttp/_websocket/mask.cpython-312-x86_64-linux-gnu.so Binary files differnew file mode 100755 index 00000000..55ede5c2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/aiohttp/_websocket/mask.cpython-312-x86_64-linux-gnu.so diff --git a/.venv/lib/python3.12/site-packages/aiohttp/_websocket/mask.pxd b/.venv/lib/python3.12/site-packages/aiohttp/_websocket/mask.pxd new file mode 100644 index 00000000..90983de9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/aiohttp/_websocket/mask.pxd @@ -0,0 +1,3 @@ +"""Cython declarations for websocket masking.""" + +cpdef void _websocket_mask_cython(bytes mask, bytearray data) diff --git a/.venv/lib/python3.12/site-packages/aiohttp/_websocket/mask.pyx b/.venv/lib/python3.12/site-packages/aiohttp/_websocket/mask.pyx new file mode 100644 index 00000000..2d956c88 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/aiohttp/_websocket/mask.pyx @@ -0,0 +1,48 @@ +from cpython cimport PyBytes_AsString + + +#from cpython cimport PyByteArray_AsString # cython still not exports that +cdef extern from "Python.h": + char* PyByteArray_AsString(bytearray ba) except NULL + +from libc.stdint cimport uint32_t, uint64_t, uintmax_t + + +cpdef void _websocket_mask_cython(bytes mask, bytearray data): + """Note, this function mutates its `data` argument + """ + cdef: + Py_ssize_t data_len, i + # bit operations on signed integers are implementation-specific + unsigned char * in_buf + const unsigned char * mask_buf + uint32_t uint32_msk + uint64_t uint64_msk + + assert len(mask) == 4 + + data_len = len(data) + in_buf = <unsigned char*>PyByteArray_AsString(data) + mask_buf = <const unsigned char*>PyBytes_AsString(mask) + uint32_msk = (<uint32_t*>mask_buf)[0] + + # TODO: align in_data ptr to achieve even faster speeds + # does it need in python ?! malloc() always aligns to sizeof(long) bytes + + if sizeof(size_t) >= 8: + uint64_msk = uint32_msk + uint64_msk = (uint64_msk << 32) | uint32_msk + + while data_len >= 8: + (<uint64_t*>in_buf)[0] ^= uint64_msk + in_buf += 8 + data_len -= 8 + + + while data_len >= 4: + (<uint32_t*>in_buf)[0] ^= uint32_msk + in_buf += 4 + data_len -= 4 + + for i in range(0, data_len): + in_buf[i] ^= mask_buf[i] diff --git a/.venv/lib/python3.12/site-packages/aiohttp/_websocket/models.py b/.venv/lib/python3.12/site-packages/aiohttp/_websocket/models.py new file mode 100644 index 00000000..7e89b965 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/aiohttp/_websocket/models.py @@ -0,0 +1,84 @@ +"""Models for WebSocket protocol versions 13 and 8.""" + +import json +from enum import IntEnum +from typing import Any, Callable, Final, NamedTuple, Optional, cast + +WS_DEFLATE_TRAILING: Final[bytes] = bytes([0x00, 0x00, 0xFF, 0xFF]) + + +class WSCloseCode(IntEnum): + OK = 1000 + GOING_AWAY = 1001 + PROTOCOL_ERROR = 1002 + UNSUPPORTED_DATA = 1003 + ABNORMAL_CLOSURE = 1006 + INVALID_TEXT = 1007 + POLICY_VIOLATION = 1008 + MESSAGE_TOO_BIG = 1009 + MANDATORY_EXTENSION = 1010 + INTERNAL_ERROR = 1011 + SERVICE_RESTART = 1012 + TRY_AGAIN_LATER = 1013 + BAD_GATEWAY = 1014 + + +class WSMsgType(IntEnum): + # websocket spec types + CONTINUATION = 0x0 + TEXT = 0x1 + BINARY = 0x2 + PING = 0x9 + PONG = 0xA + CLOSE = 0x8 + + # aiohttp specific types + CLOSING = 0x100 + CLOSED = 0x101 + ERROR = 0x102 + + text = TEXT + binary = BINARY + ping = PING + pong = PONG + close = CLOSE + closing = CLOSING + closed = CLOSED + error = ERROR + + +class WSMessage(NamedTuple): + type: WSMsgType + # To type correctly, this would need some kind of tagged union for each type. + data: Any + extra: Optional[str] + + def json(self, *, loads: Callable[[Any], Any] = json.loads) -> Any: + """Return parsed JSON data. + + .. versionadded:: 0.22 + """ + return loads(self.data) + + +# Constructing the tuple directly to avoid the overhead of +# the lambda and arg processing since NamedTuples are constructed +# with a run time built lambda +# https://github.com/python/cpython/blob/d83fcf8371f2f33c7797bc8f5423a8bca8c46e5c/Lib/collections/__init__.py#L441 +WS_CLOSED_MESSAGE = tuple.__new__(WSMessage, (WSMsgType.CLOSED, None, None)) +WS_CLOSING_MESSAGE = tuple.__new__(WSMessage, (WSMsgType.CLOSING, None, None)) + + +class WebSocketError(Exception): + """WebSocket protocol parser error.""" + + def __init__(self, code: int, message: str) -> None: + self.code = code + super().__init__(code, message) + + def __str__(self) -> str: + return cast(str, self.args[1]) + + +class WSHandshakeError(Exception): + """WebSocket protocol handshake error.""" diff --git a/.venv/lib/python3.12/site-packages/aiohttp/_websocket/reader.py b/.venv/lib/python3.12/site-packages/aiohttp/_websocket/reader.py new file mode 100644 index 00000000..23f32265 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/aiohttp/_websocket/reader.py @@ -0,0 +1,31 @@ +"""Reader for WebSocket protocol versions 13 and 8.""" + +from typing import TYPE_CHECKING + +from ..helpers import NO_EXTENSIONS + +if TYPE_CHECKING or NO_EXTENSIONS: # pragma: no cover + from .reader_py import ( + WebSocketDataQueue as WebSocketDataQueuePython, + WebSocketReader as WebSocketReaderPython, + ) + + WebSocketReader = WebSocketReaderPython + WebSocketDataQueue = WebSocketDataQueuePython +else: + try: + from .reader_c import ( # type: ignore[import-not-found] + WebSocketDataQueue as WebSocketDataQueueCython, + WebSocketReader as WebSocketReaderCython, + ) + + WebSocketReader = WebSocketReaderCython + WebSocketDataQueue = WebSocketDataQueueCython + except ImportError: # pragma: no cover + from .reader_py import ( + WebSocketDataQueue as WebSocketDataQueuePython, + WebSocketReader as WebSocketReaderPython, + ) + + WebSocketReader = WebSocketReaderPython + WebSocketDataQueue = WebSocketDataQueuePython diff --git a/.venv/lib/python3.12/site-packages/aiohttp/_websocket/reader_c.cpython-312-x86_64-linux-gnu.so b/.venv/lib/python3.12/site-packages/aiohttp/_websocket/reader_c.cpython-312-x86_64-linux-gnu.so Binary files differnew file mode 100755 index 00000000..98363ce7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/aiohttp/_websocket/reader_c.cpython-312-x86_64-linux-gnu.so diff --git a/.venv/lib/python3.12/site-packages/aiohttp/_websocket/reader_c.pxd b/.venv/lib/python3.12/site-packages/aiohttp/_websocket/reader_c.pxd new file mode 100644 index 00000000..461e658e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/aiohttp/_websocket/reader_c.pxd @@ -0,0 +1,102 @@ +import cython + +from .mask cimport _websocket_mask_cython as websocket_mask + + +cdef unsigned int READ_HEADER +cdef unsigned int READ_PAYLOAD_LENGTH +cdef unsigned int READ_PAYLOAD_MASK +cdef unsigned int READ_PAYLOAD + +cdef unsigned int OP_CODE_CONTINUATION +cdef unsigned int OP_CODE_TEXT +cdef unsigned int OP_CODE_BINARY +cdef unsigned int OP_CODE_CLOSE +cdef unsigned int OP_CODE_PING +cdef unsigned int OP_CODE_PONG + +cdef object UNPACK_LEN3 +cdef object UNPACK_CLOSE_CODE +cdef object TUPLE_NEW + +cdef object WSMsgType +cdef object WSMessage + +cdef object WS_MSG_TYPE_TEXT +cdef object WS_MSG_TYPE_BINARY + +cdef set ALLOWED_CLOSE_CODES +cdef set MESSAGE_TYPES_WITH_CONTENT + +cdef tuple EMPTY_FRAME +cdef tuple EMPTY_FRAME_ERROR + +cdef class WebSocketDataQueue: + + cdef unsigned int _size + cdef public object _protocol + cdef unsigned int _limit + cdef object _loop + cdef bint _eof + cdef object _waiter + cdef object _exception + cdef public object _buffer + cdef object _get_buffer + cdef object _put_buffer + + cdef void _release_waiter(self) + + cpdef void feed_data(self, object data, unsigned int size) + + @cython.locals(size="unsigned int") + cdef _read_from_buffer(self) + +cdef class WebSocketReader: + + cdef WebSocketDataQueue queue + cdef unsigned int _max_msg_size + + cdef Exception _exc + cdef bytearray _partial + cdef unsigned int _state + + cdef object _opcode + cdef object _frame_fin + cdef object _frame_opcode + cdef object _frame_payload + cdef unsigned long long _frame_payload_len + + cdef bytes _tail + cdef bint _has_mask + cdef bytes _frame_mask + cdef unsigned long long _payload_length + cdef unsigned int _payload_length_flag + cdef object _compressed + cdef object _decompressobj + cdef bint _compress + + cpdef tuple feed_data(self, object data) + + @cython.locals( + is_continuation=bint, + fin=bint, + has_partial=bint, + payload_merged=bytes, + opcode="unsigned int", + ) + cpdef void _feed_data(self, bytes data) + + @cython.locals( + start_pos="unsigned int", + buf_len="unsigned int", + length="unsigned int", + chunk_size="unsigned int", + chunk_len="unsigned int", + buf_length="unsigned int", + first_byte="unsigned char", + second_byte="unsigned char", + end_pos="unsigned int", + has_mask=bint, + fin=bint, + ) + cpdef list parse_frame(self, bytes buf) diff --git a/.venv/lib/python3.12/site-packages/aiohttp/_websocket/reader_c.py b/.venv/lib/python3.12/site-packages/aiohttp/_websocket/reader_c.py new file mode 100644 index 00000000..1645b394 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/aiohttp/_websocket/reader_c.py @@ -0,0 +1,469 @@ +"""Reader for WebSocket protocol versions 13 and 8.""" + +import asyncio +import builtins +from collections import deque +from typing import Deque, Final, List, Optional, Set, Tuple, Union + +from ..base_protocol import BaseProtocol +from ..compression_utils import ZLibDecompressor +from ..helpers import _EXC_SENTINEL, set_exception +from ..streams import EofStream +from .helpers import UNPACK_CLOSE_CODE, UNPACK_LEN3, websocket_mask +from .models import ( + WS_DEFLATE_TRAILING, + WebSocketError, + WSCloseCode, + WSMessage, + WSMsgType, +) + +ALLOWED_CLOSE_CODES: Final[Set[int]] = {int(i) for i in WSCloseCode} + +# States for the reader, used to parse the WebSocket frame +# integer values are used so they can be cythonized +READ_HEADER = 1 +READ_PAYLOAD_LENGTH = 2 +READ_PAYLOAD_MASK = 3 +READ_PAYLOAD = 4 + +WS_MSG_TYPE_BINARY = WSMsgType.BINARY +WS_MSG_TYPE_TEXT = WSMsgType.TEXT + +# WSMsgType values unpacked so they can by cythonized to ints +OP_CODE_CONTINUATION = WSMsgType.CONTINUATION.value +OP_CODE_TEXT = WSMsgType.TEXT.value +OP_CODE_BINARY = WSMsgType.BINARY.value +OP_CODE_CLOSE = WSMsgType.CLOSE.value +OP_CODE_PING = WSMsgType.PING.value +OP_CODE_PONG = WSMsgType.PONG.value + +EMPTY_FRAME_ERROR = (True, b"") +EMPTY_FRAME = (False, b"") + +TUPLE_NEW = tuple.__new__ + +int_ = int # Prevent Cython from converting to PyInt + + +class WebSocketDataQueue: + """WebSocketDataQueue resumes and pauses an underlying stream. + + It is a destination for WebSocket data. + """ + + def __init__( + self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop + ) -> None: + self._size = 0 + self._protocol = protocol + self._limit = limit * 2 + self._loop = loop + self._eof = False + self._waiter: Optional[asyncio.Future[None]] = None + self._exception: Union[BaseException, None] = None + self._buffer: Deque[Tuple[WSMessage, int]] = deque() + self._get_buffer = self._buffer.popleft + self._put_buffer = self._buffer.append + + def is_eof(self) -> bool: + return self._eof + + def exception(self) -> Optional[BaseException]: + return self._exception + + def set_exception( + self, + exc: "BaseException", + exc_cause: builtins.BaseException = _EXC_SENTINEL, + ) -> None: + self._eof = True + self._exception = exc + if (waiter := self._waiter) is not None: + self._waiter = None + set_exception(waiter, exc, exc_cause) + + def _release_waiter(self) -> None: + if (waiter := self._waiter) is None: + return + self._waiter = None + if not waiter.done(): + waiter.set_result(None) + + def feed_eof(self) -> None: + self._eof = True + self._release_waiter() + self._exception = None # Break cyclic references + + def feed_data(self, data: "WSMessage", size: "int_") -> None: + self._size += size + self._put_buffer((data, size)) + self._release_waiter() + if self._size > self._limit and not self._protocol._reading_paused: + self._protocol.pause_reading() + + async def read(self) -> WSMessage: + if not self._buffer and not self._eof: + assert not self._waiter + self._waiter = self._loop.create_future() + try: + await self._waiter + except (asyncio.CancelledError, asyncio.TimeoutError): + self._waiter = None + raise + return self._read_from_buffer() + + def _read_from_buffer(self) -> WSMessage: + if self._buffer: + data, size = self._get_buffer() + self._size -= size + if self._size < self._limit and self._protocol._reading_paused: + self._protocol.resume_reading() + return data + if self._exception is not None: + raise self._exception + raise EofStream + + +class WebSocketReader: + def __init__( + self, queue: WebSocketDataQueue, max_msg_size: int, compress: bool = True + ) -> None: + self.queue = queue + self._max_msg_size = max_msg_size + + self._exc: Optional[Exception] = None + self._partial = bytearray() + self._state = READ_HEADER + + self._opcode: Optional[int] = None + self._frame_fin = False + self._frame_opcode: Optional[int] = None + self._frame_payload: Union[bytes, bytearray] = b"" + self._frame_payload_len = 0 + + self._tail: bytes = b"" + self._has_mask = False + self._frame_mask: Optional[bytes] = None + self._payload_length = 0 + self._payload_length_flag = 0 + self._compressed: Optional[bool] = None + self._decompressobj: Optional[ZLibDecompressor] = None + self._compress = compress + + def feed_eof(self) -> None: + self.queue.feed_eof() + + # data can be bytearray on Windows because proactor event loop uses bytearray + # and asyncio types this to Union[bytes, bytearray, memoryview] so we need + # coerce data to bytes if it is not + def feed_data( + self, data: Union[bytes, bytearray, memoryview] + ) -> Tuple[bool, bytes]: + if type(data) is not bytes: + data = bytes(data) + + if self._exc is not None: + return True, data + + try: + self._feed_data(data) + except Exception as exc: + self._exc = exc + set_exception(self.queue, exc) + return EMPTY_FRAME_ERROR + + return EMPTY_FRAME + + def _feed_data(self, data: bytes) -> None: + msg: WSMessage + for frame in self.parse_frame(data): + fin = frame[0] + opcode = frame[1] + payload = frame[2] + compressed = frame[3] + + is_continuation = opcode == OP_CODE_CONTINUATION + if opcode == OP_CODE_TEXT or opcode == OP_CODE_BINARY or is_continuation: + # load text/binary + if not fin: + # got partial frame payload + if not is_continuation: + self._opcode = opcode + self._partial += payload + if self._max_msg_size and len(self._partial) >= self._max_msg_size: + raise WebSocketError( + WSCloseCode.MESSAGE_TOO_BIG, + "Message size {} exceeds limit {}".format( + len(self._partial), self._max_msg_size + ), + ) + continue + + has_partial = bool(self._partial) + if is_continuation: + if self._opcode is None: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + "Continuation frame for non started message", + ) + opcode = self._opcode + self._opcode = None + # previous frame was non finished + # we should get continuation opcode + elif has_partial: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + "The opcode in non-fin frame is expected " + "to be zero, got {!r}".format(opcode), + ) + + assembled_payload: Union[bytes, bytearray] + if has_partial: + assembled_payload = self._partial + payload + self._partial.clear() + else: + assembled_payload = payload + + if self._max_msg_size and len(assembled_payload) >= self._max_msg_size: + raise WebSocketError( + WSCloseCode.MESSAGE_TOO_BIG, + "Message size {} exceeds limit {}".format( + len(assembled_payload), self._max_msg_size + ), + ) + + # Decompress process must to be done after all packets + # received. + if compressed: + if not self._decompressobj: + self._decompressobj = ZLibDecompressor( + suppress_deflate_header=True + ) + payload_merged = self._decompressobj.decompress_sync( + assembled_payload + WS_DEFLATE_TRAILING, self._max_msg_size + ) + if self._decompressobj.unconsumed_tail: + left = len(self._decompressobj.unconsumed_tail) + raise WebSocketError( + WSCloseCode.MESSAGE_TOO_BIG, + "Decompressed message size {} exceeds limit {}".format( + self._max_msg_size + left, self._max_msg_size + ), + ) + elif type(assembled_payload) is bytes: + payload_merged = assembled_payload + else: + payload_merged = bytes(assembled_payload) + + if opcode == OP_CODE_TEXT: + try: + text = payload_merged.decode("utf-8") + except UnicodeDecodeError as exc: + raise WebSocketError( + WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message" + ) from exc + + # XXX: The Text and Binary messages here can be a performance + # bottleneck, so we use tuple.__new__ to improve performance. + # This is not type safe, but many tests should fail in + # test_client_ws_functional.py if this is wrong. + self.queue.feed_data( + TUPLE_NEW(WSMessage, (WS_MSG_TYPE_TEXT, text, "")), + len(payload_merged), + ) + else: + self.queue.feed_data( + TUPLE_NEW(WSMessage, (WS_MSG_TYPE_BINARY, payload_merged, "")), + len(payload_merged), + ) + elif opcode == OP_CODE_CLOSE: + if len(payload) >= 2: + close_code = UNPACK_CLOSE_CODE(payload[:2])[0] + if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + f"Invalid close code: {close_code}", + ) + try: + close_message = payload[2:].decode("utf-8") + except UnicodeDecodeError as exc: + raise WebSocketError( + WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message" + ) from exc + msg = TUPLE_NEW( + WSMessage, (WSMsgType.CLOSE, close_code, close_message) + ) + elif payload: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + f"Invalid close frame: {fin} {opcode} {payload!r}", + ) + else: + msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, 0, "")) + + self.queue.feed_data(msg, 0) + elif opcode == OP_CODE_PING: + msg = TUPLE_NEW(WSMessage, (WSMsgType.PING, payload, "")) + self.queue.feed_data(msg, len(payload)) + + elif opcode == OP_CODE_PONG: + msg = TUPLE_NEW(WSMessage, (WSMsgType.PONG, payload, "")) + self.queue.feed_data(msg, len(payload)) + + else: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}" + ) + + def parse_frame( + self, buf: bytes + ) -> List[Tuple[bool, Optional[int], Union[bytes, bytearray], Optional[bool]]]: + """Return the next frame from the socket.""" + frames: List[ + Tuple[bool, Optional[int], Union[bytes, bytearray], Optional[bool]] + ] = [] + if self._tail: + buf, self._tail = self._tail + buf, b"" + + start_pos: int = 0 + buf_length = len(buf) + + while True: + # read header + if self._state == READ_HEADER: + if buf_length - start_pos < 2: + break + first_byte = buf[start_pos] + second_byte = buf[start_pos + 1] + start_pos += 2 + + fin = (first_byte >> 7) & 1 + rsv1 = (first_byte >> 6) & 1 + rsv2 = (first_byte >> 5) & 1 + rsv3 = (first_byte >> 4) & 1 + opcode = first_byte & 0xF + + # frame-fin = %x0 ; more frames of this message follow + # / %x1 ; final frame of this message + # frame-rsv1 = %x0 ; + # 1 bit, MUST be 0 unless negotiated otherwise + # frame-rsv2 = %x0 ; + # 1 bit, MUST be 0 unless negotiated otherwise + # frame-rsv3 = %x0 ; + # 1 bit, MUST be 0 unless negotiated otherwise + # + # Remove rsv1 from this test for deflate development + if rsv2 or rsv3 or (rsv1 and not self._compress): + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + "Received frame with non-zero reserved bits", + ) + + if opcode > 0x7 and fin == 0: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + "Received fragmented control frame", + ) + + has_mask = (second_byte >> 7) & 1 + length = second_byte & 0x7F + + # Control frames MUST have a payload + # length of 125 bytes or less + if opcode > 0x7 and length > 125: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + "Control frame payload cannot be larger than 125 bytes", + ) + + # Set compress status if last package is FIN + # OR set compress status if this is first fragment + # Raise error if not first fragment with rsv1 = 0x1 + if self._frame_fin or self._compressed is None: + self._compressed = True if rsv1 else False + elif rsv1: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + "Received frame with non-zero reserved bits", + ) + + self._frame_fin = bool(fin) + self._frame_opcode = opcode + self._has_mask = bool(has_mask) + self._payload_length_flag = length + self._state = READ_PAYLOAD_LENGTH + + # read payload length + if self._state == READ_PAYLOAD_LENGTH: + length_flag = self._payload_length_flag + if length_flag == 126: + if buf_length - start_pos < 2: + break + first_byte = buf[start_pos] + second_byte = buf[start_pos + 1] + start_pos += 2 + self._payload_length = first_byte << 8 | second_byte + elif length_flag > 126: + if buf_length - start_pos < 8: + break + data = buf[start_pos : start_pos + 8] + start_pos += 8 + self._payload_length = UNPACK_LEN3(data)[0] + else: + self._payload_length = length_flag + + self._state = READ_PAYLOAD_MASK if self._has_mask else READ_PAYLOAD + + # read payload mask + if self._state == READ_PAYLOAD_MASK: + if buf_length - start_pos < 4: + break + self._frame_mask = buf[start_pos : start_pos + 4] + start_pos += 4 + self._state = READ_PAYLOAD + + if self._state == READ_PAYLOAD: + chunk_len = buf_length - start_pos + if self._payload_length >= chunk_len: + end_pos = buf_length + self._payload_length -= chunk_len + else: + end_pos = start_pos + self._payload_length + self._payload_length = 0 + + if self._frame_payload_len: + if type(self._frame_payload) is not bytearray: + self._frame_payload = bytearray(self._frame_payload) + self._frame_payload += buf[start_pos:end_pos] + else: + # Fast path for the first frame + self._frame_payload = buf[start_pos:end_pos] + + self._frame_payload_len += end_pos - start_pos + start_pos = end_pos + + if self._payload_length != 0: + break + + if self._has_mask: + assert self._frame_mask is not None + if type(self._frame_payload) is not bytearray: + self._frame_payload = bytearray(self._frame_payload) + websocket_mask(self._frame_mask, self._frame_payload) + + frames.append( + ( + self._frame_fin, + self._frame_opcode, + self._frame_payload, + self._compressed, + ) + ) + self._frame_payload = b"" + self._frame_payload_len = 0 + self._state = READ_HEADER + + self._tail = buf[start_pos:] if start_pos < buf_length else b"" + + return frames diff --git a/.venv/lib/python3.12/site-packages/aiohttp/_websocket/reader_py.py b/.venv/lib/python3.12/site-packages/aiohttp/_websocket/reader_py.py new file mode 100644 index 00000000..1645b394 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/aiohttp/_websocket/reader_py.py @@ -0,0 +1,469 @@ +"""Reader for WebSocket protocol versions 13 and 8.""" + +import asyncio +import builtins +from collections import deque +from typing import Deque, Final, List, Optional, Set, Tuple, Union + +from ..base_protocol import BaseProtocol +from ..compression_utils import ZLibDecompressor +from ..helpers import _EXC_SENTINEL, set_exception +from ..streams import EofStream +from .helpers import UNPACK_CLOSE_CODE, UNPACK_LEN3, websocket_mask +from .models import ( + WS_DEFLATE_TRAILING, + WebSocketError, + WSCloseCode, + WSMessage, + WSMsgType, +) + +ALLOWED_CLOSE_CODES: Final[Set[int]] = {int(i) for i in WSCloseCode} + +# States for the reader, used to parse the WebSocket frame +# integer values are used so they can be cythonized +READ_HEADER = 1 +READ_PAYLOAD_LENGTH = 2 +READ_PAYLOAD_MASK = 3 +READ_PAYLOAD = 4 + +WS_MSG_TYPE_BINARY = WSMsgType.BINARY +WS_MSG_TYPE_TEXT = WSMsgType.TEXT + +# WSMsgType values unpacked so they can by cythonized to ints +OP_CODE_CONTINUATION = WSMsgType.CONTINUATION.value +OP_CODE_TEXT = WSMsgType.TEXT.value +OP_CODE_BINARY = WSMsgType.BINARY.value +OP_CODE_CLOSE = WSMsgType.CLOSE.value +OP_CODE_PING = WSMsgType.PING.value +OP_CODE_PONG = WSMsgType.PONG.value + +EMPTY_FRAME_ERROR = (True, b"") +EMPTY_FRAME = (False, b"") + +TUPLE_NEW = tuple.__new__ + +int_ = int # Prevent Cython from converting to PyInt + + +class WebSocketDataQueue: + """WebSocketDataQueue resumes and pauses an underlying stream. + + It is a destination for WebSocket data. + """ + + def __init__( + self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop + ) -> None: + self._size = 0 + self._protocol = protocol + self._limit = limit * 2 + self._loop = loop + self._eof = False + self._waiter: Optional[asyncio.Future[None]] = None + self._exception: Union[BaseException, None] = None + self._buffer: Deque[Tuple[WSMessage, int]] = deque() + self._get_buffer = self._buffer.popleft + self._put_buffer = self._buffer.append + + def is_eof(self) -> bool: + return self._eof + + def exception(self) -> Optional[BaseException]: + return self._exception + + def set_exception( + self, + exc: "BaseException", + exc_cause: builtins.BaseException = _EXC_SENTINEL, + ) -> None: + self._eof = True + self._exception = exc + if (waiter := self._waiter) is not None: + self._waiter = None + set_exception(waiter, exc, exc_cause) + + def _release_waiter(self) -> None: + if (waiter := self._waiter) is None: + return + self._waiter = None + if not waiter.done(): + waiter.set_result(None) + + def feed_eof(self) -> None: + self._eof = True + self._release_waiter() + self._exception = None # Break cyclic references + + def feed_data(self, data: "WSMessage", size: "int_") -> None: + self._size += size + self._put_buffer((data, size)) + self._release_waiter() + if self._size > self._limit and not self._protocol._reading_paused: + self._protocol.pause_reading() + + async def read(self) -> WSMessage: + if not self._buffer and not self._eof: + assert not self._waiter + self._waiter = self._loop.create_future() + try: + await self._waiter + except (asyncio.CancelledError, asyncio.TimeoutError): + self._waiter = None + raise + return self._read_from_buffer() + + def _read_from_buffer(self) -> WSMessage: + if self._buffer: + data, size = self._get_buffer() + self._size -= size + if self._size < self._limit and self._protocol._reading_paused: + self._protocol.resume_reading() + return data + if self._exception is not None: + raise self._exception + raise EofStream + + +class WebSocketReader: + def __init__( + self, queue: WebSocketDataQueue, max_msg_size: int, compress: bool = True + ) -> None: + self.queue = queue + self._max_msg_size = max_msg_size + + self._exc: Optional[Exception] = None + self._partial = bytearray() + self._state = READ_HEADER + + self._opcode: Optional[int] = None + self._frame_fin = False + self._frame_opcode: Optional[int] = None + self._frame_payload: Union[bytes, bytearray] = b"" + self._frame_payload_len = 0 + + self._tail: bytes = b"" + self._has_mask = False + self._frame_mask: Optional[bytes] = None + self._payload_length = 0 + self._payload_length_flag = 0 + self._compressed: Optional[bool] = None + self._decompressobj: Optional[ZLibDecompressor] = None + self._compress = compress + + def feed_eof(self) -> None: + self.queue.feed_eof() + + # data can be bytearray on Windows because proactor event loop uses bytearray + # and asyncio types this to Union[bytes, bytearray, memoryview] so we need + # coerce data to bytes if it is not + def feed_data( + self, data: Union[bytes, bytearray, memoryview] + ) -> Tuple[bool, bytes]: + if type(data) is not bytes: + data = bytes(data) + + if self._exc is not None: + return True, data + + try: + self._feed_data(data) + except Exception as exc: + self._exc = exc + set_exception(self.queue, exc) + return EMPTY_FRAME_ERROR + + return EMPTY_FRAME + + def _feed_data(self, data: bytes) -> None: + msg: WSMessage + for frame in self.parse_frame(data): + fin = frame[0] + opcode = frame[1] + payload = frame[2] + compressed = frame[3] + + is_continuation = opcode == OP_CODE_CONTINUATION + if opcode == OP_CODE_TEXT or opcode == OP_CODE_BINARY or is_continuation: + # load text/binary + if not fin: + # got partial frame payload + if not is_continuation: + self._opcode = opcode + self._partial += payload + if self._max_msg_size and len(self._partial) >= self._max_msg_size: + raise WebSocketError( + WSCloseCode.MESSAGE_TOO_BIG, + "Message size {} exceeds limit {}".format( + len(self._partial), self._max_msg_size + ), + ) + continue + + has_partial = bool(self._partial) + if is_continuation: + if self._opcode is None: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + "Continuation frame for non started message", + ) + opcode = self._opcode + self._opcode = None + # previous frame was non finished + # we should get continuation opcode + elif has_partial: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + "The opcode in non-fin frame is expected " + "to be zero, got {!r}".format(opcode), + ) + + assembled_payload: Union[bytes, bytearray] + if has_partial: + assembled_payload = self._partial + payload + self._partial.clear() + else: + assembled_payload = payload + + if self._max_msg_size and len(assembled_payload) >= self._max_msg_size: + raise WebSocketError( + WSCloseCode.MESSAGE_TOO_BIG, + "Message size {} exceeds limit {}".format( + len(assembled_payload), self._max_msg_size + ), + ) + + # Decompress process must to be done after all packets + # received. + if compressed: + if not self._decompressobj: + self._decompressobj = ZLibDecompressor( + suppress_deflate_header=True + ) + payload_merged = self._decompressobj.decompress_sync( + assembled_payload + WS_DEFLATE_TRAILING, self._max_msg_size + ) + if self._decompressobj.unconsumed_tail: + left = len(self._decompressobj.unconsumed_tail) + raise WebSocketError( + WSCloseCode.MESSAGE_TOO_BIG, + "Decompressed message size {} exceeds limit {}".format( + self._max_msg_size + left, self._max_msg_size + ), + ) + elif type(assembled_payload) is bytes: + payload_merged = assembled_payload + else: + payload_merged = bytes(assembled_payload) + + if opcode == OP_CODE_TEXT: + try: + text = payload_merged.decode("utf-8") + except UnicodeDecodeError as exc: + raise WebSocketError( + WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message" + ) from exc + + # XXX: The Text and Binary messages here can be a performance + # bottleneck, so we use tuple.__new__ to improve performance. + # This is not type safe, but many tests should fail in + # test_client_ws_functional.py if this is wrong. + self.queue.feed_data( + TUPLE_NEW(WSMessage, (WS_MSG_TYPE_TEXT, text, "")), + len(payload_merged), + ) + else: + self.queue.feed_data( + TUPLE_NEW(WSMessage, (WS_MSG_TYPE_BINARY, payload_merged, "")), + len(payload_merged), + ) + elif opcode == OP_CODE_CLOSE: + if len(payload) >= 2: + close_code = UNPACK_CLOSE_CODE(payload[:2])[0] + if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + f"Invalid close code: {close_code}", + ) + try: + close_message = payload[2:].decode("utf-8") + except UnicodeDecodeError as exc: + raise WebSocketError( + WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message" + ) from exc + msg = TUPLE_NEW( + WSMessage, (WSMsgType.CLOSE, close_code, close_message) + ) + elif payload: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + f"Invalid close frame: {fin} {opcode} {payload!r}", + ) + else: + msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, 0, "")) + + self.queue.feed_data(msg, 0) + elif opcode == OP_CODE_PING: + msg = TUPLE_NEW(WSMessage, (WSMsgType.PING, payload, "")) + self.queue.feed_data(msg, len(payload)) + + elif opcode == OP_CODE_PONG: + msg = TUPLE_NEW(WSMessage, (WSMsgType.PONG, payload, "")) + self.queue.feed_data(msg, len(payload)) + + else: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}" + ) + + def parse_frame( + self, buf: bytes + ) -> List[Tuple[bool, Optional[int], Union[bytes, bytearray], Optional[bool]]]: + """Return the next frame from the socket.""" + frames: List[ + Tuple[bool, Optional[int], Union[bytes, bytearray], Optional[bool]] + ] = [] + if self._tail: + buf, self._tail = self._tail + buf, b"" + + start_pos: int = 0 + buf_length = len(buf) + + while True: + # read header + if self._state == READ_HEADER: + if buf_length - start_pos < 2: + break + first_byte = buf[start_pos] + second_byte = buf[start_pos + 1] + start_pos += 2 + + fin = (first_byte >> 7) & 1 + rsv1 = (first_byte >> 6) & 1 + rsv2 = (first_byte >> 5) & 1 + rsv3 = (first_byte >> 4) & 1 + opcode = first_byte & 0xF + + # frame-fin = %x0 ; more frames of this message follow + # / %x1 ; final frame of this message + # frame-rsv1 = %x0 ; + # 1 bit, MUST be 0 unless negotiated otherwise + # frame-rsv2 = %x0 ; + # 1 bit, MUST be 0 unless negotiated otherwise + # frame-rsv3 = %x0 ; + # 1 bit, MUST be 0 unless negotiated otherwise + # + # Remove rsv1 from this test for deflate development + if rsv2 or rsv3 or (rsv1 and not self._compress): + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + "Received frame with non-zero reserved bits", + ) + + if opcode > 0x7 and fin == 0: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + "Received fragmented control frame", + ) + + has_mask = (second_byte >> 7) & 1 + length = second_byte & 0x7F + + # Control frames MUST have a payload + # length of 125 bytes or less + if opcode > 0x7 and length > 125: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + "Control frame payload cannot be larger than 125 bytes", + ) + + # Set compress status if last package is FIN + # OR set compress status if this is first fragment + # Raise error if not first fragment with rsv1 = 0x1 + if self._frame_fin or self._compressed is None: + self._compressed = True if rsv1 else False + elif rsv1: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + "Received frame with non-zero reserved bits", + ) + + self._frame_fin = bool(fin) + self._frame_opcode = opcode + self._has_mask = bool(has_mask) + self._payload_length_flag = length + self._state = READ_PAYLOAD_LENGTH + + # read payload length + if self._state == READ_PAYLOAD_LENGTH: + length_flag = self._payload_length_flag + if length_flag == 126: + if buf_length - start_pos < 2: + break + first_byte = buf[start_pos] + second_byte = buf[start_pos + 1] + start_pos += 2 + self._payload_length = first_byte << 8 | second_byte + elif length_flag > 126: + if buf_length - start_pos < 8: + break + data = buf[start_pos : start_pos + 8] + start_pos += 8 + self._payload_length = UNPACK_LEN3(data)[0] + else: + self._payload_length = length_flag + + self._state = READ_PAYLOAD_MASK if self._has_mask else READ_PAYLOAD + + # read payload mask + if self._state == READ_PAYLOAD_MASK: + if buf_length - start_pos < 4: + break + self._frame_mask = buf[start_pos : start_pos + 4] + start_pos += 4 + self._state = READ_PAYLOAD + + if self._state == READ_PAYLOAD: + chunk_len = buf_length - start_pos + if self._payload_length >= chunk_len: + end_pos = buf_length + self._payload_length -= chunk_len + else: + end_pos = start_pos + self._payload_length + self._payload_length = 0 + + if self._frame_payload_len: + if type(self._frame_payload) is not bytearray: + self._frame_payload = bytearray(self._frame_payload) + self._frame_payload += buf[start_pos:end_pos] + else: + # Fast path for the first frame + self._frame_payload = buf[start_pos:end_pos] + + self._frame_payload_len += end_pos - start_pos + start_pos = end_pos + + if self._payload_length != 0: + break + + if self._has_mask: + assert self._frame_mask is not None + if type(self._frame_payload) is not bytearray: + self._frame_payload = bytearray(self._frame_payload) + websocket_mask(self._frame_mask, self._frame_payload) + + frames.append( + ( + self._frame_fin, + self._frame_opcode, + self._frame_payload, + self._compressed, + ) + ) + self._frame_payload = b"" + self._frame_payload_len = 0 + self._state = READ_HEADER + + self._tail = buf[start_pos:] if start_pos < buf_length else b"" + + return frames diff --git a/.venv/lib/python3.12/site-packages/aiohttp/_websocket/writer.py b/.venv/lib/python3.12/site-packages/aiohttp/_websocket/writer.py new file mode 100644 index 00000000..fc2cf32b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/aiohttp/_websocket/writer.py @@ -0,0 +1,177 @@ +"""WebSocket protocol versions 13 and 8.""" + +import asyncio +import random +import zlib +from functools import partial +from typing import Any, Final, Optional, Union + +from ..base_protocol import BaseProtocol +from ..client_exceptions import ClientConnectionResetError +from ..compression_utils import ZLibCompressor +from .helpers import ( + MASK_LEN, + MSG_SIZE, + PACK_CLOSE_CODE, + PACK_LEN1, + PACK_LEN2, + PACK_LEN3, + PACK_RANDBITS, + websocket_mask, +) +from .models import WS_DEFLATE_TRAILING, WSMsgType + +DEFAULT_LIMIT: Final[int] = 2**16 + +# For websockets, keeping latency low is extremely important as implementations +# generally expect to be able to send and receive messages quickly. We use a +# larger chunk size than the default to reduce the number of executor calls +# since the executor is a significant source of latency and overhead when +# the chunks are small. A size of 5KiB was chosen because it is also the +# same value python-zlib-ng choose to use as the threshold to release the GIL. + +WEBSOCKET_MAX_SYNC_CHUNK_SIZE = 5 * 1024 + + +class WebSocketWriter: + """WebSocket writer. + + The writer is responsible for sending messages to the client. It is + created by the protocol when a connection is established. The writer + should avoid implementing any application logic and should only be + concerned with the low-level details of the WebSocket protocol. + """ + + def __init__( + self, + protocol: BaseProtocol, + transport: asyncio.Transport, + *, + use_mask: bool = False, + limit: int = DEFAULT_LIMIT, + random: random.Random = random.Random(), + compress: int = 0, + notakeover: bool = False, + ) -> None: + """Initialize a WebSocket writer.""" + self.protocol = protocol + self.transport = transport + self.use_mask = use_mask + self.get_random_bits = partial(random.getrandbits, 32) + self.compress = compress + self.notakeover = notakeover + self._closing = False + self._limit = limit + self._output_size = 0 + self._compressobj: Any = None # actually compressobj + + async def send_frame( + self, message: bytes, opcode: int, compress: Optional[int] = None + ) -> None: + """Send a frame over the websocket with message as its payload.""" + if self._closing and not (opcode & WSMsgType.CLOSE): + raise ClientConnectionResetError("Cannot write to closing transport") + + # RSV are the reserved bits in the frame header. They are used to + # indicate that the frame is using an extension. + # https://datatracker.ietf.org/doc/html/rfc6455#section-5.2 + rsv = 0 + # Only compress larger packets (disabled) + # Does small packet needs to be compressed? + # if self.compress and opcode < 8 and len(message) > 124: + if (compress or self.compress) and opcode < 8: + # RSV1 (rsv = 0x40) is set for compressed frames + # https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1 + rsv = 0x40 + + if compress: + # Do not set self._compress if compressing is for this frame + compressobj = self._make_compress_obj(compress) + else: # self.compress + if not self._compressobj: + self._compressobj = self._make_compress_obj(self.compress) + compressobj = self._compressobj + + message = ( + await compressobj.compress(message) + + compressobj.flush( + zlib.Z_FULL_FLUSH if self.notakeover else zlib.Z_SYNC_FLUSH + ) + ).removesuffix(WS_DEFLATE_TRAILING) + # Its critical that we do not return control to the event + # loop until we have finished sending all the compressed + # data. Otherwise we could end up mixing compressed frames + # if there are multiple coroutines compressing data. + + msg_length = len(message) + + use_mask = self.use_mask + mask_bit = 0x80 if use_mask else 0 + + # Depending on the message length, the header is assembled differently. + # The first byte is reserved for the opcode and the RSV bits. + first_byte = 0x80 | rsv | opcode + if msg_length < 126: + header = PACK_LEN1(first_byte, msg_length | mask_bit) + header_len = 2 + elif msg_length < 65536: + header = PACK_LEN2(first_byte, 126 | mask_bit, msg_length) + header_len = 4 + else: + header = PACK_LEN3(first_byte, 127 | mask_bit, msg_length) + header_len = 10 + + if self.transport.is_closing(): + raise ClientConnectionResetError("Cannot write to closing transport") + + # https://datatracker.ietf.org/doc/html/rfc6455#section-5.3 + # If we are using a mask, we need to generate it randomly + # and apply it to the message before sending it. A mask is + # a 32-bit value that is applied to the message using a + # bitwise XOR operation. It is used to prevent certain types + # of attacks on the websocket protocol. The mask is only used + # when aiohttp is acting as a client. Servers do not use a mask. + if use_mask: + mask = PACK_RANDBITS(self.get_random_bits()) + message = bytearray(message) + websocket_mask(mask, message) + self.transport.write(header + mask + message) + self._output_size += MASK_LEN + elif msg_length > MSG_SIZE: + self.transport.write(header) + self.transport.write(message) + else: + self.transport.write(header + message) + + self._output_size += header_len + msg_length + + # It is safe to return control to the event loop when using compression + # after this point as we have already sent or buffered all the data. + + # Once we have written output_size up to the limit, we call the + # drain helper which waits for the transport to be ready to accept + # more data. This is a flow control mechanism to prevent the buffer + # from growing too large. The drain helper will return right away + # if the writer is not paused. + if self._output_size > self._limit: + self._output_size = 0 + if self.protocol._paused: + await self.protocol._drain_helper() + + def _make_compress_obj(self, compress: int) -> ZLibCompressor: + return ZLibCompressor( + level=zlib.Z_BEST_SPEED, + wbits=-compress, + max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE, + ) + + async def close(self, code: int = 1000, message: Union[bytes, str] = b"") -> None: + """Close the websocket, sending the specified code and message.""" + if isinstance(message, str): + message = message.encode("utf-8") + try: + await self.send_frame( + PACK_CLOSE_CODE(code) + message, opcode=WSMsgType.CLOSE + ) + finally: + self._closing = True |