about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/realtime/_async/client.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/realtime/_async/client.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/realtime/_async/client.py')
-rw-r--r--.venv/lib/python3.12/site-packages/realtime/_async/client.py398
1 files changed, 398 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/realtime/_async/client.py b/.venv/lib/python3.12/site-packages/realtime/_async/client.py
new file mode 100644
index 00000000..9e8b669d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/realtime/_async/client.py
@@ -0,0 +1,398 @@
+import asyncio
+import json
+import logging
+import re
+from base64 import b64decode
+from datetime import datetime
+from functools import wraps
+from math import floor
+from typing import Any, Callable, Dict, List, Optional
+from urllib.parse import urlencode, urlparse, urlunparse
+
+import websockets
+from websockets import connect
+from websockets.client import ClientProtocol
+
+from ..message import Message
+from ..transformers import http_endpoint_url
+from ..types import (
+    DEFAULT_HEARTBEAT_INTERVAL,
+    DEFAULT_TIMEOUT,
+    PHOENIX_CHANNEL,
+    VSN,
+    ChannelEvents,
+)
+from ..utils import is_ws_url
+from .channel import AsyncRealtimeChannel, RealtimeChannelOptions
+
+
+def deprecated(func: Callable) -> Callable:
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        logger.warning(f"Warning: {func.__name__} is deprecated.")
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
+logger = logging.getLogger(__name__)
+
+
+class AsyncRealtimeClient:
+    def __init__(
+        self,
+        url: str,
+        token: Optional[str] = None,
+        auto_reconnect: bool = True,
+        params: Optional[Dict[str, Any]] = None,
+        hb_interval: int = DEFAULT_HEARTBEAT_INTERVAL,
+        max_retries: int = 5,
+        initial_backoff: float = 1.0,
+        timeout: int = DEFAULT_TIMEOUT,
+    ) -> None:
+        """
+        Initialize a RealtimeClient instance for WebSocket communication.
+
+        :param url: WebSocket URL of the Realtime server. Starts with `ws://` or `wss://`.
+                   Also accepts default Supabase URL: `http://` or `https://`.
+        :param token: Authentication token for the WebSocket connection.
+        :param auto_reconnect: If True, automatically attempt to reconnect on disconnection. Defaults to True.
+        :param params: Optional parameters for the connection. Defaults to None.
+        :param hb_interval: Interval (in seconds) for sending heartbeat messages to keep the connection alive. Defaults to 25.
+        :param max_retries: Maximum number of reconnection attempts. Defaults to 5.
+        :param initial_backoff: Initial backoff time (in seconds) for reconnection attempts. Defaults to 1.0.
+        :param timeout: Connection timeout in seconds. Defaults to DEFAULT_TIMEOUT.
+        """
+        if not is_ws_url(url):
+            ValueError("url must be a valid WebSocket URL or HTTP URL string")
+        self.url = f"{re.sub(r'https://', 'wss://', re.sub(r'http://', 'ws://', url, flags=re.IGNORECASE), flags=re.IGNORECASE)}/websocket"
+        if token:
+            self.url += f"?apikey={token}"
+        self.http_endpoint = http_endpoint_url(url)
+        self.params = params or {}
+        self.apikey = token
+        self.access_token = token
+        self.send_buffer: List[Callable] = []
+        self.hb_interval = hb_interval
+        self.ws_connection: Optional[ClientProtocol] = None
+        self.ref = 0
+        self.auto_reconnect = auto_reconnect
+        self.channels: Dict[str, AsyncRealtimeChannel] = {}
+        self.max_retries = max_retries
+        self.initial_backoff = initial_backoff
+        self.timeout = timeout
+        self._listen_task: Optional[asyncio.Task] = None
+        self._heartbeat_task: Optional[asyncio.Task] = None
+
+    @property
+    def is_connected(self) -> bool:
+        return self.ws_connection is not None
+
+    async def _listen(self) -> None:
+        """
+        An infinite loop that keeps listening.
+        :return: None
+        """
+
+        if not self.ws_connection:
+            raise Exception("WebSocket connection not established")
+
+        try:
+            async for msg in self.ws_connection:
+                logger.info(f"receive: {msg}")
+
+                msg = Message(**json.loads(msg))
+                channel = self.channels.get(msg.topic)
+
+                if channel:
+                    channel._trigger(msg.event, msg.payload, msg.ref)
+        except websockets.exceptions.ConnectionClosedError as e:
+            logger.error(
+                f"WebSocket connection closed with code: {e.code}, reason: {e.reason}"
+            )
+            if self.auto_reconnect:
+                logger.info("Initiating auto-reconnect sequence...")
+
+                await self._reconnect()
+            else:
+                logger.error("Auto-reconnect disabled, terminating connection")
+
+    async def _reconnect(self) -> None:
+        self.ws_connection = None
+        await self.connect()
+
+        if self.is_connected:
+            for topic, channel in self.channels.items():
+                logger.info(f"Rejoining channel after reconnection: {topic}")
+                await channel._rejoin()
+
+    async def connect(self) -> None:
+        """
+        Establishes a WebSocket connection with exponential backoff retry mechanism.
+
+        This method attempts to connect to the WebSocket server. If the connection fails,
+        it will retry with an exponential backoff strategy up to a maximum number of retries.
+
+        Returns:
+            None
+
+        Raises:
+            Exception: If unable to establish a connection after max_retries attempts.
+
+        Note:
+            - The initial backoff time and maximum retries are set during RealtimeClient initialization.
+            - The backoff time doubles after each failed attempt, up to a maximum of 60 seconds.
+        """
+
+        if self.is_connected:
+            logger.info("WebSocket connection already established")
+            return
+
+        retries = 0
+        backoff = self.initial_backoff
+
+        logger.info(f"Attempting to connect to WebSocket at {self.url}")
+
+        while retries < self.max_retries:
+            try:
+                ws = await connect(self.url)
+                self.ws_connection = ws
+                logger.info("WebSocket connection established successfully")
+                return await self._on_connect()
+            except Exception as e:
+                retries += 1
+                logger.error(f"Connection attempt failed: {str(e)}")
+
+                if retries >= self.max_retries or not self.auto_reconnect:
+                    logger.error(
+                        f"Connection failed permanently after {retries} attempts. Error: {str(e)}"
+                    )
+                    raise
+                else:
+                    wait_time = backoff * (2 ** (retries - 1))
+                    logger.info(
+                        f"Retry {retries}/{self.max_retries}: Next attempt in {wait_time:.2f}s (backoff={backoff}s)"
+                    )
+                    await asyncio.sleep(wait_time)
+                    backoff = min(backoff * 2, 60)
+
+        raise Exception(
+            f"Failed to establish WebSocket connection after {self.max_retries} attempts"
+        )
+
+    @deprecated
+    async def listen(self):
+        pass
+
+    async def _on_connect(self) -> None:
+        self._listen_task = asyncio.create_task(self._listen())
+        self._heartbeat_task = asyncio.create_task(self._heartbeat())
+
+        await self._flush_send_buffer()
+
+    async def _flush_send_buffer(self):
+        if self.is_connected and len(self.send_buffer) > 0:
+            for callback in self.send_buffer:
+                await callback()
+            self.send_buffer = []
+
+    async def close(self) -> None:
+        """
+        Close the WebSocket connection.
+
+        Returns:
+            None
+
+        Raises:
+            NotConnectedError: If the connection is not established when this method is called.
+        """
+
+        if self.ws_connection:
+            await self.ws_connection.close()
+
+        self.ws_connection = None
+
+        if self._listen_task:
+            self._listen_task.cancel()
+            self._listen_task = None
+
+        if self._heartbeat_task:
+            self._heartbeat_task.cancel()
+            self._heartbeat_task = None
+
+    async def _heartbeat(self) -> None:
+        if not self.ws_connection:
+            raise Exception("WebSocket connection not established")
+
+        while self.is_connected:
+            try:
+                data = dict(
+                    topic=PHOENIX_CHANNEL,
+                    event=ChannelEvents.heartbeat,
+                    payload={},
+                    ref=None,
+                )
+                await self.send(data)
+                await asyncio.sleep(max(self.hb_interval, 15))
+
+            except websockets.exceptions.ConnectionClosed as e:
+                logger.error(
+                    f"Connection closed during heartbeat. Code: {e.code}, reason: {e.reason}"
+                )
+
+                if self.auto_reconnect:
+                    logger.info("Heartbeat failed - initiating reconnection sequence")
+                    await self._reconnect()
+                else:
+                    logger.error("Heartbeat failed - auto-reconnect disabled")
+                    break
+
+    def channel(
+        self, topic: str, params: Optional[RealtimeChannelOptions] = None
+    ) -> AsyncRealtimeChannel:
+        """
+        Initialize a channel and create a two-way association with the socket.
+
+        :param topic: The topic to subscribe to
+        :param params: Optional channel parameters
+        :return: AsyncRealtimeChannel instance
+        """
+        topic = f"realtime:{topic}"
+        chan = AsyncRealtimeChannel(self, topic, params)
+        self.channels[topic] = chan
+
+        return chan
+
+    def get_channels(self) -> List[AsyncRealtimeChannel]:
+        return list(self.channels.values())
+
+    def _remove_channel(self, channel: AsyncRealtimeChannel) -> None:
+        del self.channels[channel.topic]
+
+    async def remove_channel(self, channel: AsyncRealtimeChannel) -> None:
+        """
+        Unsubscribes and removes a channel from the socket
+        :param channel: Channel to remove
+        :return: None
+        """
+        if channel.topic in self.channels:
+            await self.channels[channel.topic].unsubscribe()
+
+        if len(self.channels) == 0:
+            await self.close()
+
+    async def remove_all_channels(self) -> None:
+        """
+        Unsubscribes and removes all channels from the socket
+        :return: None
+        """
+        for _, channel in self.channels.items():
+            await channel.unsubscribe()
+
+        await self.close()
+
+    def summary(self) -> None:
+        """
+        Prints a list of topics and event the socket is listening to
+        :return: None
+        """
+        for topic, channel in self.channels.items():
+            print(f"Topic: {topic} | Events: {[e for e, _ in channel.listeners]}]")
+
+    async def set_auth(self, token: Optional[str]) -> None:
+        """
+        Set the authentication token for the connection and update all joined channels.
+
+        This method updates the access token for the current connection and sends the new token
+        to all joined channels. This is useful for refreshing authentication or changing users.
+
+        Args:
+            token (Optional[str]): The new authentication token. Can be None to remove authentication.
+
+        Returns:
+            None
+        """
+        # No empty string tokens.
+        if isinstance(token, str) and len(token.strip()) == 0:
+            raise ValueError("Provide a valid jwt token")
+
+        if token:
+            parsed = None
+            try:
+                payload = token.split(".")[1] + "=="
+                parsed = json.loads(b64decode(payload).decode("utf-8"))
+            except Exception:
+                raise ValueError("InvalidJWTToken")
+
+            if parsed:
+                # Handle expired token if any.
+                if "exp" in parsed:
+                    now = floor(datetime.now().timestamp())
+                    valid = now - parsed["exp"] < 0
+                    if not valid:
+                        raise ValueError(
+                            f"InvalidJWTToken: Invalid value for JWT claim 'exp' with value { parsed['exp'] }"
+                        )
+                else:
+                    raise ValueError("InvalidJWTToken: expected claim 'exp'")
+
+        self.access_token = token
+
+        for _, channel in self.channels.items():
+            if channel._joined_once and channel.is_joined:
+                await channel.push(ChannelEvents.access_token, {"access_token": token})
+
+    def _make_ref(self) -> str:
+        self.ref += 1
+        return f"{self.ref}"
+
+    async def send(self, message: Dict[str, Any]) -> None:
+        """
+        Send a message through the WebSocket connection.
+
+        This method serializes the given message dictionary to JSON,
+        and sends it through the WebSocket connection. If the connection
+        is not currently established, the message will be buffered and sent
+        once the connection is re-established.
+
+        Args:
+            message (Dict[str, Any]): The message to be sent, as a dictionary.
+
+        Returns:
+            None
+        """
+
+        message = json.dumps(message)
+        logger.info(f"send: {message}")
+
+        async def send_message():
+            await self.ws_connection.send(message)
+
+        if self.is_connected:
+            await send_message()
+        else:
+            self.send_buffer.append(send_message)
+
+    async def _leave_open_topic(self, topic: str):
+        dup_channels = [
+            ch
+            for ch in self.channels.values()
+            if ch.topic == topic and (ch.is_joined or ch.is_joining)
+        ]
+
+        for ch in dup_channels:
+            await ch.unsubscribe()
+
+    def endpoint_url(self) -> str:
+        parsed_url = urlparse(self.url)
+        query = urlencode({**self.params, "vsn": VSN}, doseq=True)
+        return urlunparse(
+            (
+                parsed_url.scheme,
+                parsed_url.netloc,
+                parsed_url.path,
+                parsed_url.params,
+                query,
+                parsed_url.fragment,
+            )
+        )