about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/realtime/_async/channel.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/realtime/_async/channel.py')
-rw-r--r--.venv/lib/python3.12/site-packages/realtime/_async/channel.py565
1 files changed, 565 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/realtime/_async/channel.py b/.venv/lib/python3.12/site-packages/realtime/_async/channel.py
new file mode 100644
index 00000000..0050e5e2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/realtime/_async/channel.py
@@ -0,0 +1,565 @@
+from __future__ import annotations
+
+import asyncio
+import json
+import logging
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
+
+from realtime.types import (
+    Binding,
+    Callback,
+    ChannelEvents,
+    ChannelStates,
+    RealtimeChannelOptions,
+    RealtimePostgresChangesListenEvent,
+    RealtimePresenceState,
+    RealtimeSubscribeStates,
+)
+
+from ..transformers import http_endpoint_url
+from .presence import (
+    AsyncRealtimePresence,
+    PresenceOnJoinCallback,
+    PresenceOnLeaveCallback,
+)
+from .push import AsyncPush
+from .timer import AsyncTimer
+
+if TYPE_CHECKING:
+    from .client import AsyncRealtimeClient
+
+logger = logging.getLogger(__name__)
+
+
+class AsyncRealtimeChannel:
+    """
+    Channel is an abstraction for a topic subscription on an existing socket connection.
+    Each Channel has its own topic and a list of event-callbacks that respond to messages.
+    Should only be instantiated through `AsyncRealtimeClient.channel(topic)`.
+    """
+
+    def __init__(
+        self,
+        socket: AsyncRealtimeClient,
+        topic: str,
+        params: Optional[RealtimeChannelOptions] = None,
+    ) -> None:
+        """
+        Initialize the Channel object.
+
+        :param socket: RealtimeClient object
+        :param topic: Topic that it subscribes to on the realtime server
+        :param params: Optional parameters for connection.
+        """
+        self.socket = socket
+        self.params = params or {}
+        if self.params.get("config") is None:
+            self.params["config"] = {
+                "broadcast": {"ack": False, "self": False},
+                "presence": {"key": ""},
+                "private": False,
+            }
+
+        self.topic = topic
+        self._joined_once = False
+        self.bindings: Dict[str, List[Binding]] = {}
+        self.presence = AsyncRealtimePresence(self)
+        self.state = ChannelStates.CLOSED
+        self._push_buffer: List[AsyncPush] = []
+        self.timeout = self.socket.timeout
+
+        self.join_push = AsyncPush(self, ChannelEvents.join, self.params)
+        self.rejoin_timer = AsyncTimer(
+            self._rejoin_until_connected, lambda tries: 2**tries
+        )
+
+        self.broadcast_endpoint_url = self._broadcast_endpoint_url()
+
+        def on_join_push_ok(payload: Dict[str, Any], *args):
+            self.state = ChannelStates.JOINED
+            self.rejoin_timer.reset()
+            for push in self._push_buffer:
+                asyncio.create_task(push.send())
+            self._push_buffer = []
+
+        def on_join_push_timeout(*args):
+            if not self.is_joining:
+                return
+
+            logger.error(f"join push timeout for channel {self.topic}")
+            self.state = ChannelStates.ERRORED
+            self.rejoin_timer.schedule_timeout()
+
+        self.join_push.receive("ok", on_join_push_ok).receive(
+            "timeout", on_join_push_timeout
+        )
+
+        def on_close(*args):
+            logger.info(f"channel {self.topic} closed")
+            self.rejoin_timer.reset()
+            self.state = ChannelStates.CLOSED
+            self.socket._remove_channel(self)
+
+        def on_error(payload, *args):
+            if self.is_leaving or self.is_closed:
+                return
+
+            logger.info(f"channel {self.topic} error: {payload}")
+            self.state = ChannelStates.ERRORED
+            self.rejoin_timer.schedule_timeout()
+
+        self._on("close", on_close)
+        self._on("error", on_error)
+
+        def on_reply(payload, ref):
+            self._trigger(self._reply_event_name(ref), payload)
+
+        self._on(ChannelEvents.reply, on_reply)
+
+    # Properties
+    @property
+    def is_closed(self):
+        return self.state == ChannelStates.CLOSED
+
+    @property
+    def is_joining(self):
+        return self.state == ChannelStates.JOINING
+
+    @property
+    def is_leaving(self):
+        return self.state == ChannelStates.LEAVING
+
+    @property
+    def is_errored(self):
+        return self.state == ChannelStates.ERRORED
+
+    @property
+    def is_joined(self):
+        return self.state == ChannelStates.JOINED
+
+    @property
+    def join_ref(self):
+        return self.join_push.ref
+
+    # Core channel methods
+    async def subscribe(
+        self,
+        callback: Optional[
+            Callable[[RealtimeSubscribeStates, Optional[Exception]], None]
+        ] = None,
+    ) -> AsyncRealtimeChannel:
+        """
+        Subscribe to the channel. Can only be called once per channel instance.
+
+        :param callback: Optional callback function that receives subscription state updates
+                        and any errors that occur during subscription
+        :return: The Channel instance for method chaining
+        :raises: Exception if called multiple times on the same channel instance
+        """
+        if not self.socket.is_connected:
+            await self.socket.connect()
+
+        if self._joined_once:
+            raise Exception(
+                "Tried to subscribe multiple times. 'subscribe' can only be called a single time per channel instance"
+            )
+        else:
+            config = self.params.get("config", {})
+            broadcast = config.get("broadcast", {})
+            presence = config.get("presence", {})
+            private = config.get("private", False)
+
+            access_token_payload = {}
+            config = {
+                "broadcast": broadcast,
+                "presence": presence,
+                "private": private,
+                "postgres_changes": list(
+                    map(lambda x: x.filter, self.bindings.get("postgres_changes", []))
+                ),
+            }
+
+            if self.socket.access_token:
+                access_token_payload["access_token"] = self.socket.access_token
+
+            self.join_push.update_payload(
+                {**{"config": config}, **access_token_payload}
+            )
+            self._joined_once = True
+
+            def on_join_push_ok(payload: Dict[str, Any]):
+                server_postgres_changes: List[Dict[str, Any]] = payload.get(
+                    "postgres_changes", []
+                )
+
+                if len(server_postgres_changes) == 0:
+                    callback and callback(RealtimeSubscribeStates.SUBSCRIBED, None)
+                    return
+
+                client_postgres_changes = self.bindings.get("postgres_changes", [])
+                new_postgres_bindings = []
+
+                bindings_len = len(client_postgres_changes)
+
+                for i in range(bindings_len):
+                    client_binding = client_postgres_changes[i]
+                    event = client_binding.filter.get("event")
+                    schema = client_binding.filter.get("schema")
+                    table = client_binding.filter.get("table")
+                    filter = client_binding.filter.get("filter")
+
+                    server_binding = (
+                        server_postgres_changes[i]
+                        if i < len(server_postgres_changes)
+                        else None
+                    )
+
+                    if (
+                        server_binding
+                        and server_binding.get("event") == event
+                        and server_binding.get("schema") == schema
+                        and server_binding.get("table") == table
+                        and server_binding.get("filter") == filter
+                    ):
+                        client_binding.id = server_binding.get("id")
+                        new_postgres_bindings.append(client_binding)
+                    else:
+                        asyncio.create_task(self.unsubscribe())
+                        callback and callback(
+                            RealtimeSubscribeStates.CHANNEL_ERROR,
+                            Exception(
+                                "mismatch between server and client bindings for postgres changes"
+                            ),
+                        )
+                        return
+
+                self.bindings["postgres_changes"] = new_postgres_bindings
+                callback and callback(RealtimeSubscribeStates.SUBSCRIBED, None)
+
+            def on_join_push_error(payload: Dict[str, Any]):
+                callback and callback(
+                    RealtimeSubscribeStates.CHANNEL_ERROR,
+                    Exception(json.dumps(payload)),
+                )
+
+            def on_join_push_timeout(*args):
+                callback and callback(RealtimeSubscribeStates.TIMED_OUT, None)
+
+            self.join_push.receive("ok", on_join_push_ok).receive(
+                "error", on_join_push_error
+            ).receive("timeout", on_join_push_timeout)
+
+            await self._rejoin()
+
+        return self
+
+    async def unsubscribe(self):
+        """
+        Unsubscribe from the channel and leave the topic.
+        Sets channel state to LEAVING and cleans up timers and pushes.
+        """
+        self.state = ChannelStates.LEAVING
+
+        self.rejoin_timer.reset()
+        self.join_push.destroy()
+
+        def _close(*args):
+            logger.info(f"channel {self.topic} leave")
+            self._trigger(ChannelEvents.close, "leave")
+
+        leave_push = AsyncPush(self, ChannelEvents.leave, {})
+        leave_push.receive("ok", _close).receive("timeout", _close)
+
+        await leave_push.send()
+
+        if not self._can_push():
+            leave_push.trigger("ok", {})
+
+    async def push(
+        self, event: str, payload: Dict[str, Any], timeout: Optional[int] = None
+    ) -> AsyncPush:
+        """
+        Push a message to the channel.
+
+        :param event: The event name to push
+        :param payload: The payload to send
+        :param timeout: Optional timeout in milliseconds
+        :return: AsyncPush instance representing the push operation
+        :raises: Exception if called before subscribing to the channel
+        """
+        if not self._joined_once:
+            raise Exception(
+                f"tried to push '{event}' to '{self.topic}' before joining. Use channel.subscribe() before pushing events"
+            )
+
+        timeout = timeout or self.timeout
+
+        push = AsyncPush(self, event, payload, timeout)
+        if self._can_push():
+            await push.send()
+        else:
+            push.start_timeout()
+            self._push_buffer.append(push)
+
+        return push
+
+    async def join(self) -> AsyncRealtimeChannel:
+        """
+        Coroutine that attempts to join Phoenix Realtime server via a certain topic.
+
+        :return: Channel
+        """
+        try:
+            await self.socket.send(
+                {
+                    "topic": self.topic,
+                    "event": "phx_join",
+                    "payload": {"config": self.params},
+                    "ref": None,
+                }
+            )
+        except Exception as e:
+            print(e)
+            return self
+
+    # Event handling methods
+    def _on(
+        self, type: str, callback: Callback, filter: Optional[Dict[str, Any]] = None
+    ) -> AsyncRealtimeChannel:
+        """
+        Set up a listener for a specific event.
+
+        :param type: The type of the event to listen for.
+        :param filter: Additional parameters for the event.
+        :param callback: The callback function to execute when the event is received.
+        :return: The Channel instance for method chaining.
+        """
+
+        type_lowercase = type.lower()
+        binding = Binding(type=type_lowercase, filter=filter or {}, callback=callback)
+        self.bindings.setdefault(type_lowercase, []).append(binding)
+
+        return self
+
+    def _off(self, type: str, filter: Dict[str, Any]) -> AsyncRealtimeChannel:
+        """
+        Remove a listener for a specific event type and filter.
+
+        :param type: The type of the event to remove the listener for.
+        :param filter: The filter associated with the listener to remove.
+        :return: The Channel instance for method chaining.
+
+        This method removes all bindings for the specified event type that match
+        the given filter. If no matching bindings are found, the method does nothing.
+        """
+        type_lowercase = type.lower()
+
+        if type_lowercase in self.bindings:
+            self.bindings[type_lowercase] = [
+                binding
+                for binding in self.bindings[type_lowercase]
+                if binding.filter != filter
+            ]
+        return self
+
+    def on_broadcast(
+        self, event: str, callback: Callable[[Dict[str, Any]], None]
+    ) -> AsyncRealtimeChannel:
+        """
+        Set up a listener for a specific broadcast event.
+
+        :param event: The name of the broadcast event to listen for
+        :param callback: Function called with the payload when a matching broadcast is received
+        :return: The Channel instance for method chaining
+        """
+        return self._on(
+            "broadcast",
+            filter={"event": event},
+            callback=lambda payload, _: callback(payload),
+        )
+
+    def on_postgres_changes(
+        self,
+        event: RealtimePostgresChangesListenEvent,
+        callback: Callable[[Dict[str, Any]], None],
+        table: str = "*",
+        schema: str = "public",
+        filter: Optional[str] = None,
+    ) -> AsyncRealtimeChannel:
+        """
+        Set up a listener for Postgres database changes.
+
+        :param event: The type of database event to listen for (INSERT, UPDATE, DELETE, or *)
+        :param callback: Function called with the payload when a matching change is detected
+        :param table: The table name to monitor. Defaults to "*" for all tables
+        :param schema: The database schema to monitor. Defaults to "public"
+        :param filter: Optional filter string to apply
+        :return: The Channel instance for method chaining
+        """
+
+        binding_filter = {"event": event, "schema": schema, "table": table}
+        if filter:
+            binding_filter["filter"] = filter
+
+        return self._on(
+            "postgres_changes",
+            filter=binding_filter,
+            callback=lambda payload, _: callback(payload),
+        )
+
+    def on_system(
+        self, callback: Callable[[Dict[str, Any], None]]
+    ) -> AsyncRealtimeChannel:
+        """
+        Set up a listener for system events.
+
+        :param callback: The callback function to execute when a system event is received.
+        :return: The Channel instance for method chaining.
+        """
+        return self._on("system", callback=lambda payload, _: callback(payload))
+
+    # Presence methods
+    async def track(self, user_status: Dict[str, Any]) -> None:
+        """
+        Track presence status for the current user.
+
+        :param user_status: Dictionary containing the user's presence information
+        """
+        await self.send_presence("track", user_status)
+
+    async def untrack(self) -> None:
+        """
+        Stop tracking presence for the current user.
+        """
+        await self.send_presence("untrack", {})
+
+    def presence_state(self) -> RealtimePresenceState:
+        """
+        Get the current state of presence on this channel.
+
+        :return: Dictionary mapping presence keys to lists of presence payloads
+        """
+        return self.presence.state
+
+    def on_presence_sync(self, callback: Callable[[], None]) -> AsyncRealtimeChannel:
+        """
+        Register a callback for presence sync events.
+
+        :param callback: The callback function to execute when a presence sync event occurs.
+        :return: The Channel instance for method chaining.
+        """
+        self.presence.on_sync(callback)
+        return self
+
+    def on_presence_join(
+        self, callback: PresenceOnJoinCallback
+    ) -> AsyncRealtimeChannel:
+        """
+        Register a callback for presence join events.
+
+        :param callback: The callback function to execute when a presence join event occurs.
+        :return: The Channel instance for method chaining.
+        """
+        self.presence.on_join(callback)
+        return self
+
+    def on_presence_leave(
+        self, callback: PresenceOnLeaveCallback
+    ) -> AsyncRealtimeChannel:
+        """
+        Register a callback for presence leave events.
+
+        :param callback: The callback function to execute when a presence leave event occurs.
+        :return: The Channel instance for method chaining.
+        """
+        self.presence.on_leave(callback)
+        return self
+
+    # Broadcast methods
+    async def send_broadcast(self, event: str, data: Any) -> None:
+        """
+        Send a broadcast message through this channel.
+
+        :param event: The name of the broadcast event
+        :param data: The payload to broadcast
+        """
+        await self.push(
+            ChannelEvents.broadcast,
+            {"type": "broadcast", "event": event, "payload": data},
+        )
+
+    # Internal methods
+    def _broadcast_endpoint_url(self):
+        return f"{http_endpoint_url(self.socket.http_endpoint)}/api/broadcast"
+
+    async def _rejoin(self) -> None:
+        if self.is_leaving:
+            return
+        await self.socket._leave_open_topic(self.topic)
+        self.state = ChannelStates.JOINING
+        await self.join_push.resend()
+
+    def _can_push(self):
+        return self.socket.is_connected and self._joined_once
+
+    async def send_presence(self, event: str, data: Any) -> None:
+        await self.push(ChannelEvents.presence, {"event": event, "payload": data})
+
+    def _trigger(self, type: str, payload: Optional[Any], ref: Optional[str] = None):
+        type_lowercase = type.lower()
+        events = [
+            ChannelEvents.close,
+            ChannelEvents.error,
+            ChannelEvents.leave,
+            ChannelEvents.join,
+        ]
+
+        if ref is not None and type_lowercase in events and ref != self.join_push.ref:
+            return
+
+        if type_lowercase in ["insert", "update", "delete"]:
+            postgres_changes = filter(
+                lambda binding: binding.filter.get("event", "").lower()
+                in [type_lowercase, "*"],
+                self.bindings.get("postgres_changes", []),
+            )
+            for binding in postgres_changes:
+                binding.callback(payload, ref)
+        else:
+            bindings = self.bindings.get(type_lowercase, [])
+            for binding in bindings:
+                if type_lowercase in ["broadcast", "postgres_changes", "presence"]:
+                    bind_id = binding.id
+                    bind_event = (
+                        binding.filter.get("event", "").lower()
+                        if binding.filter.get("event")
+                        else None
+                    )
+                    payload_event = (
+                        payload.get("event", "").lower()
+                        if payload.get("event")
+                        else None
+                    )
+                    data_type = (
+                        payload.get("data", {}).get("type", "").lower()
+                        if payload.get("data", {}).get("type")
+                        else None
+                    )
+                    if (
+                        bind_id
+                        and bind_id in payload.get("ids", [])
+                        and (bind_event == data_type or bind_event == "*")
+                    ):
+                        binding.callback(payload, ref)
+                    elif bind_event in [payload_event, "*"]:
+                        binding.callback(payload, ref)
+                elif binding.type == type_lowercase:
+                    binding.callback(payload, ref)
+
+    def _reply_event_name(self, ref: str):
+        return f"chan_reply_{ref}"
+
+    async def _rejoin_until_connected(self):
+        self.rejoin_timer.schedule_timeout()
+        if self.socket.is_connected:
+            await self._rejoin()