aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/websockets/asyncio/client.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/websockets/asyncio/client.py')
-rw-r--r--.venv/lib/python3.12/site-packages/websockets/asyncio/client.py567
1 files changed, 567 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/websockets/asyncio/client.py b/.venv/lib/python3.12/site-packages/websockets/asyncio/client.py
new file mode 100644
index 00000000..f05f546d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/websockets/asyncio/client.py
@@ -0,0 +1,567 @@
+from __future__ import annotations
+
+import asyncio
+import logging
+import os
+import traceback
+import urllib.parse
+from collections.abc import AsyncIterator, Generator, Sequence
+from types import TracebackType
+from typing import Any, Callable
+
+from ..client import ClientProtocol, backoff
+from ..datastructures import HeadersLike
+from ..exceptions import InvalidMessage, InvalidStatus, SecurityError
+from ..extensions.base import ClientExtensionFactory
+from ..extensions.permessage_deflate import enable_client_permessage_deflate
+from ..headers import validate_subprotocols
+from ..http11 import USER_AGENT, Response
+from ..protocol import CONNECTING, Event
+from ..typing import LoggerLike, Origin, Subprotocol
+from ..uri import WebSocketURI, parse_uri
+from .compatibility import TimeoutError, asyncio_timeout
+from .connection import Connection
+
+
+__all__ = ["connect", "unix_connect", "ClientConnection"]
+
+MAX_REDIRECTS = int(os.environ.get("WEBSOCKETS_MAX_REDIRECTS", "10"))
+
+
+class ClientConnection(Connection):
+ """
+ :mod:`asyncio` implementation of a WebSocket client connection.
+
+ :class:`ClientConnection` provides :meth:`recv` and :meth:`send` coroutines
+ for receiving and sending messages.
+
+ It supports asynchronous iteration to receive messages::
+
+ async for message in websocket:
+ await process(message)
+
+ The iterator exits normally when the connection is closed with close code
+ 1000 (OK) or 1001 (going away) or without a close code. It raises a
+ :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is
+ closed with any other code.
+
+ The ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_queue``,
+ and ``write_limit`` arguments have the same meaning as in :func:`connect`.
+
+ Args:
+ protocol: Sans-I/O connection.
+
+ """
+
+ def __init__(
+ self,
+ protocol: ClientProtocol,
+ *,
+ ping_interval: float | None = 20,
+ ping_timeout: float | None = 20,
+ close_timeout: float | None = 10,
+ max_queue: int | None | tuple[int | None, int | None] = 16,
+ write_limit: int | tuple[int, int | None] = 2**15,
+ ) -> None:
+ self.protocol: ClientProtocol
+ super().__init__(
+ protocol,
+ ping_interval=ping_interval,
+ ping_timeout=ping_timeout,
+ close_timeout=close_timeout,
+ max_queue=max_queue,
+ write_limit=write_limit,
+ )
+ self.response_rcvd: asyncio.Future[None] = self.loop.create_future()
+
+ async def handshake(
+ self,
+ additional_headers: HeadersLike | None = None,
+ user_agent_header: str | None = USER_AGENT,
+ ) -> None:
+ """
+ Perform the opening handshake.
+
+ """
+ async with self.send_context(expected_state=CONNECTING):
+ self.request = self.protocol.connect()
+ if additional_headers is not None:
+ self.request.headers.update(additional_headers)
+ if user_agent_header:
+ self.request.headers["User-Agent"] = user_agent_header
+ self.protocol.send_request(self.request)
+
+ await asyncio.wait(
+ [self.response_rcvd, self.connection_lost_waiter],
+ return_when=asyncio.FIRST_COMPLETED,
+ )
+
+ # self.protocol.handshake_exc is set when the connection is lost before
+ # receiving a response, when the response cannot be parsed, or when the
+ # response fails the handshake.
+
+ if self.protocol.handshake_exc is not None:
+ raise self.protocol.handshake_exc
+
+ def process_event(self, event: Event) -> None:
+ """
+ Process one incoming event.
+
+ """
+ # First event - handshake response.
+ if self.response is None:
+ assert isinstance(event, Response)
+ self.response = event
+ self.response_rcvd.set_result(None)
+ # Later events - frames.
+ else:
+ super().process_event(event)
+
+
+def process_exception(exc: Exception) -> Exception | None:
+ """
+ Determine whether a connection error is retryable or fatal.
+
+ When reconnecting automatically with ``async for ... in connect(...)``, if a
+ connection attempt fails, :func:`process_exception` is called to determine
+ whether to retry connecting or to raise the exception.
+
+ This function defines the default behavior, which is to retry on:
+
+ * :exc:`EOFError`, :exc:`OSError`, :exc:`asyncio.TimeoutError`: network
+ errors;
+ * :exc:`~websockets.exceptions.InvalidStatus` when the status code is 500,
+ 502, 503, or 504: server or proxy errors.
+
+ All other exceptions are considered fatal.
+
+ You can change this behavior with the ``process_exception`` argument of
+ :func:`connect`.
+
+ Return :obj:`None` if the exception is retryable i.e. when the error could
+ be transient and trying to reconnect with the same parameters could succeed.
+ The exception will be logged at the ``INFO`` level.
+
+ Return an exception, either ``exc`` or a new exception, if the exception is
+ fatal i.e. when trying to reconnect will most likely produce the same error.
+ That exception will be raised, breaking out of the retry loop.
+
+ """
+ if isinstance(exc, (OSError, asyncio.TimeoutError)):
+ return None
+ if isinstance(exc, InvalidMessage) and isinstance(exc.__cause__, EOFError):
+ return None
+ if isinstance(exc, InvalidStatus) and exc.response.status_code in [
+ 500, # Internal Server Error
+ 502, # Bad Gateway
+ 503, # Service Unavailable
+ 504, # Gateway Timeout
+ ]:
+ return None
+ return exc
+
+
+# This is spelled in lower case because it's exposed as a callable in the API.
+class connect:
+ """
+ Connect to the WebSocket server at ``uri``.
+
+ This coroutine returns a :class:`ClientConnection` instance, which you can
+ use to send and receive messages.
+
+ :func:`connect` may be used as an asynchronous context manager::
+
+ from websockets.asyncio.client import connect
+
+ async with connect(...) as websocket:
+ ...
+
+ The connection is closed automatically when exiting the context.
+
+ :func:`connect` can be used as an infinite asynchronous iterator to
+ reconnect automatically on errors::
+
+ async for websocket in connect(...):
+ try:
+ ...
+ except websockets.exceptions.ConnectionClosed:
+ continue
+
+ If the connection fails with a transient error, it is retried with
+ exponential backoff. If it fails with a fatal error, the exception is
+ raised, breaking out of the loop.
+
+ The connection is closed automatically after each iteration of the loop.
+
+ Args:
+ uri: URI of the WebSocket server.
+ origin: Value of the ``Origin`` header, for servers that require it.
+ extensions: List of supported extensions, in order in which they
+ should be negotiated and run.
+ subprotocols: List of supported subprotocols, in order of decreasing
+ preference.
+ additional_headers (HeadersLike | None): Arbitrary HTTP headers to add
+ to the handshake request.
+ user_agent_header: Value of the ``User-Agent`` request header.
+ It defaults to ``"Python/x.y.z websockets/X.Y"``.
+ Setting it to :obj:`None` removes the header.
+ compression: The "permessage-deflate" extension is enabled by default.
+ Set ``compression`` to :obj:`None` to disable it. See the
+ :doc:`compression guide <../../topics/compression>` for details.
+ process_exception: When reconnecting automatically, tell whether an
+ error is transient or fatal. The default behavior is defined by
+ :func:`process_exception`. Refer to its documentation for details.
+ open_timeout: Timeout for opening the connection in seconds.
+ :obj:`None` disables the timeout.
+ ping_interval: Interval between keepalive pings in seconds.
+ :obj:`None` disables keepalive.
+ ping_timeout: Timeout for keepalive pings in seconds.
+ :obj:`None` disables timeouts.
+ close_timeout: Timeout for closing the connection in seconds.
+ :obj:`None` disables the timeout.
+ max_size: Maximum size of incoming messages in bytes.
+ :obj:`None` disables the limit.
+ max_queue: High-water mark of the buffer where frames are received.
+ It defaults to 16 frames. The low-water mark defaults to ``max_queue
+ // 4``. You may pass a ``(high, low)`` tuple to set the high-water
+ and low-water marks. If you want to disable flow control entirely,
+ you may set it to ``None``, although that's a bad idea.
+ write_limit: High-water mark of write buffer in bytes. It is passed to
+ :meth:`~asyncio.WriteTransport.set_write_buffer_limits`. It defaults
+ to 32 KiB. You may pass a ``(high, low)`` tuple to set the
+ high-water and low-water marks.
+ logger: Logger for this client.
+ It defaults to ``logging.getLogger("websockets.client")``.
+ See the :doc:`logging guide <../../topics/logging>` for details.
+ create_connection: Factory for the :class:`ClientConnection` managing
+ the connection. Set it to a wrapper or a subclass to customize
+ connection handling.
+
+ Any other keyword arguments are passed to the event loop's
+ :meth:`~asyncio.loop.create_connection` method.
+
+ For example:
+
+ * You can set ``ssl`` to a :class:`~ssl.SSLContext` to enforce TLS settings.
+ When connecting to a ``wss://`` URI, if ``ssl`` isn't provided, a TLS
+ context is created with :func:`~ssl.create_default_context`.
+
+ * You can set ``server_hostname`` to override the host name from ``uri`` in
+ the TLS handshake.
+
+ * You can set ``host`` and ``port`` to connect to a different host and port
+ from those found in ``uri``. This only changes the destination of the TCP
+ connection. The host name from ``uri`` is still used in the TLS handshake
+ for secure connections and in the ``Host`` header.
+
+ * You can set ``sock`` to provide a preexisting TCP socket. You may call
+ :func:`socket.create_connection` (not to be confused with the event loop's
+ :meth:`~asyncio.loop.create_connection` method) to create a suitable
+ client socket and customize it.
+
+ Raises:
+ InvalidURI: If ``uri`` isn't a valid WebSocket URI.
+ OSError: If the TCP connection fails.
+ InvalidHandshake: If the opening handshake fails.
+ TimeoutError: If the opening handshake times out.
+
+ """
+
+ def __init__(
+ self,
+ uri: str,
+ *,
+ # WebSocket
+ origin: Origin | None = None,
+ extensions: Sequence[ClientExtensionFactory] | None = None,
+ subprotocols: Sequence[Subprotocol] | None = None,
+ additional_headers: HeadersLike | None = None,
+ user_agent_header: str | None = USER_AGENT,
+ compression: str | None = "deflate",
+ process_exception: Callable[[Exception], Exception | None] = process_exception,
+ # Timeouts
+ open_timeout: float | None = 10,
+ ping_interval: float | None = 20,
+ ping_timeout: float | None = 20,
+ close_timeout: float | None = 10,
+ # Limits
+ max_size: int | None = 2**20,
+ max_queue: int | None | tuple[int | None, int | None] = 16,
+ write_limit: int | tuple[int, int | None] = 2**15,
+ # Logging
+ logger: LoggerLike | None = None,
+ # Escape hatch for advanced customization
+ create_connection: type[ClientConnection] | None = None,
+ # Other keyword arguments are passed to loop.create_connection
+ **kwargs: Any,
+ ) -> None:
+ self.uri = uri
+
+ if subprotocols is not None:
+ validate_subprotocols(subprotocols)
+
+ if compression == "deflate":
+ extensions = enable_client_permessage_deflate(extensions)
+ elif compression is not None:
+ raise ValueError(f"unsupported compression: {compression}")
+
+ if logger is None:
+ logger = logging.getLogger("websockets.client")
+
+ if create_connection is None:
+ create_connection = ClientConnection
+
+ def protocol_factory(wsuri: WebSocketURI) -> ClientConnection:
+ # This is a protocol in the Sans-I/O implementation of websockets.
+ protocol = ClientProtocol(
+ wsuri,
+ origin=origin,
+ extensions=extensions,
+ subprotocols=subprotocols,
+ max_size=max_size,
+ logger=logger,
+ )
+ # This is a connection in websockets and a protocol in asyncio.
+ connection = create_connection(
+ protocol,
+ ping_interval=ping_interval,
+ ping_timeout=ping_timeout,
+ close_timeout=close_timeout,
+ max_queue=max_queue,
+ write_limit=write_limit,
+ )
+ return connection
+
+ self.protocol_factory = protocol_factory
+ self.handshake_args = (
+ additional_headers,
+ user_agent_header,
+ )
+ self.process_exception = process_exception
+ self.open_timeout = open_timeout
+ self.logger = logger
+ self.connection_kwargs = kwargs
+
+ async def create_connection(self) -> ClientConnection:
+ """Create TCP or Unix connection."""
+ loop = asyncio.get_running_loop()
+
+ wsuri = parse_uri(self.uri)
+ kwargs = self.connection_kwargs.copy()
+
+ def factory() -> ClientConnection:
+ return self.protocol_factory(wsuri)
+
+ if wsuri.secure:
+ kwargs.setdefault("ssl", True)
+ kwargs.setdefault("server_hostname", wsuri.host)
+ if kwargs.get("ssl") is None:
+ raise ValueError("ssl=None is incompatible with a wss:// URI")
+ else:
+ if kwargs.get("ssl") is not None:
+ raise ValueError("ssl argument is incompatible with a ws:// URI")
+
+ if kwargs.pop("unix", False):
+ _, connection = await loop.create_unix_connection(factory, **kwargs)
+ else:
+ if kwargs.get("sock") is None:
+ kwargs.setdefault("host", wsuri.host)
+ kwargs.setdefault("port", wsuri.port)
+ _, connection = await loop.create_connection(factory, **kwargs)
+ return connection
+
+ def process_redirect(self, exc: Exception) -> Exception | str:
+ """
+ Determine whether a connection error is a redirect that can be followed.
+
+ Return the new URI if it's a valid redirect. Else, return an exception.
+
+ """
+ if not (
+ isinstance(exc, InvalidStatus)
+ and exc.response.status_code
+ in [
+ 300, # Multiple Choices
+ 301, # Moved Permanently
+ 302, # Found
+ 303, # See Other
+ 307, # Temporary Redirect
+ 308, # Permanent Redirect
+ ]
+ and "Location" in exc.response.headers
+ ):
+ return exc
+
+ old_wsuri = parse_uri(self.uri)
+ new_uri = urllib.parse.urljoin(self.uri, exc.response.headers["Location"])
+ new_wsuri = parse_uri(new_uri)
+
+ # If connect() received a socket, it is closed and cannot be reused.
+ if self.connection_kwargs.get("sock") is not None:
+ return ValueError(
+ f"cannot follow redirect to {new_uri} with a preexisting socket"
+ )
+
+ # TLS downgrade is forbidden.
+ if old_wsuri.secure and not new_wsuri.secure:
+ return SecurityError(f"cannot follow redirect to non-secure URI {new_uri}")
+
+ # Apply restrictions to cross-origin redirects.
+ if (
+ old_wsuri.secure != new_wsuri.secure
+ or old_wsuri.host != new_wsuri.host
+ or old_wsuri.port != new_wsuri.port
+ ):
+ # Cross-origin redirects on Unix sockets don't quite make sense.
+ if self.connection_kwargs.get("unix", False):
+ return ValueError(
+ f"cannot follow cross-origin redirect to {new_uri} "
+ f"with a Unix socket"
+ )
+
+ # Cross-origin redirects when host and port are overridden are ill-defined.
+ if (
+ self.connection_kwargs.get("host") is not None
+ or self.connection_kwargs.get("port") is not None
+ ):
+ return ValueError(
+ f"cannot follow cross-origin redirect to {new_uri} "
+ f"with an explicit host or port"
+ )
+
+ return new_uri
+
+ # ... = await connect(...)
+
+ def __await__(self) -> Generator[Any, None, ClientConnection]:
+ # Create a suitable iterator by calling __await__ on a coroutine.
+ return self.__await_impl__().__await__()
+
+ async def __await_impl__(self) -> ClientConnection:
+ try:
+ async with asyncio_timeout(self.open_timeout):
+ for _ in range(MAX_REDIRECTS):
+ self.connection = await self.create_connection()
+ try:
+ await self.connection.handshake(*self.handshake_args)
+ except asyncio.CancelledError:
+ self.connection.transport.abort()
+ raise
+ except Exception as exc:
+ # Always close the connection even though keep-alive is
+ # the default in HTTP/1.1 because create_connection ties
+ # opening the network connection with initializing the
+ # protocol. In the current design of connect(), there is
+ # no easy way to reuse the network connection that works
+ # in every case nor to reinitialize the protocol.
+ self.connection.transport.abort()
+
+ uri_or_exc = self.process_redirect(exc)
+ # Response is a valid redirect; follow it.
+ if isinstance(uri_or_exc, str):
+ self.uri = uri_or_exc
+ continue
+ # Response isn't a valid redirect; raise the exception.
+ if uri_or_exc is exc:
+ raise
+ else:
+ raise uri_or_exc from exc
+
+ else:
+ self.connection.start_keepalive()
+ return self.connection
+ else:
+ raise SecurityError(f"more than {MAX_REDIRECTS} redirects")
+
+ except TimeoutError:
+ # Re-raise exception with an informative error message.
+ raise TimeoutError("timed out during handshake") from None
+
+ # ... = yield from connect(...) - remove when dropping Python < 3.10
+
+ __iter__ = __await__
+
+ # async with connect(...) as ...: ...
+
+ async def __aenter__(self) -> ClientConnection:
+ return await self
+
+ async def __aexit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_value: BaseException | None,
+ traceback: TracebackType | None,
+ ) -> None:
+ await self.connection.close()
+
+ # async for ... in connect(...):
+
+ async def __aiter__(self) -> AsyncIterator[ClientConnection]:
+ delays: Generator[float] | None = None
+ while True:
+ try:
+ async with self as protocol:
+ yield protocol
+ except Exception as exc:
+ # Determine whether the exception is retryable or fatal.
+ # The API of process_exception is "return an exception or None";
+ # "raise an exception" is also supported because it's a frequent
+ # mistake. It isn't documented in order to keep the API simple.
+ try:
+ new_exc = self.process_exception(exc)
+ except Exception as raised_exc:
+ new_exc = raised_exc
+
+ # The connection failed with a fatal error.
+ # Raise the exception and exit the loop.
+ if new_exc is exc:
+ raise
+ if new_exc is not None:
+ raise new_exc from exc
+
+ # The connection failed with a retryable error.
+ # Start or continue backoff and reconnect.
+ if delays is None:
+ delays = backoff()
+ delay = next(delays)
+ self.logger.info(
+ "connect failed; reconnecting in %.1f seconds: %s",
+ delay,
+ # Remove first argument when dropping Python 3.9.
+ traceback.format_exception_only(type(exc), exc)[0].strip(),
+ )
+ await asyncio.sleep(delay)
+ continue
+
+ else:
+ # The connection succeeded. Reset backoff.
+ delays = None
+
+
+def unix_connect(
+ path: str | None = None,
+ uri: str | None = None,
+ **kwargs: Any,
+) -> connect:
+ """
+ Connect to a WebSocket server listening on a Unix socket.
+
+ This function accepts the same keyword arguments as :func:`connect`.
+
+ It's only available on Unix.
+
+ It's mainly useful for debugging servers listening on Unix sockets.
+
+ Args:
+ path: File system path to the Unix socket.
+ uri: URI of the WebSocket server. ``uri`` defaults to
+ ``ws://localhost/`` or, when a ``ssl`` argument is provided, to
+ ``wss://localhost/``.
+
+ """
+ if uri is None:
+ if kwargs.get("ssl") is None:
+ uri = "ws://localhost/"
+ else:
+ uri = "wss://localhost/"
+ return connect(uri=uri, unix=True, path=path, **kwargs)