aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/realtime/_async/channel.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/realtime/_async/channel.py')
-rw-r--r--.venv/lib/python3.12/site-packages/realtime/_async/channel.py565
1 files changed, 565 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/realtime/_async/channel.py b/.venv/lib/python3.12/site-packages/realtime/_async/channel.py
new file mode 100644
index 00000000..0050e5e2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/realtime/_async/channel.py
@@ -0,0 +1,565 @@
+from __future__ import annotations
+
+import asyncio
+import json
+import logging
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
+
+from realtime.types import (
+ Binding,
+ Callback,
+ ChannelEvents,
+ ChannelStates,
+ RealtimeChannelOptions,
+ RealtimePostgresChangesListenEvent,
+ RealtimePresenceState,
+ RealtimeSubscribeStates,
+)
+
+from ..transformers import http_endpoint_url
+from .presence import (
+ AsyncRealtimePresence,
+ PresenceOnJoinCallback,
+ PresenceOnLeaveCallback,
+)
+from .push import AsyncPush
+from .timer import AsyncTimer
+
+if TYPE_CHECKING:
+ from .client import AsyncRealtimeClient
+
+logger = logging.getLogger(__name__)
+
+
+class AsyncRealtimeChannel:
+ """
+ Channel is an abstraction for a topic subscription on an existing socket connection.
+ Each Channel has its own topic and a list of event-callbacks that respond to messages.
+ Should only be instantiated through `AsyncRealtimeClient.channel(topic)`.
+ """
+
+ def __init__(
+ self,
+ socket: AsyncRealtimeClient,
+ topic: str,
+ params: Optional[RealtimeChannelOptions] = None,
+ ) -> None:
+ """
+ Initialize the Channel object.
+
+ :param socket: RealtimeClient object
+ :param topic: Topic that it subscribes to on the realtime server
+ :param params: Optional parameters for connection.
+ """
+ self.socket = socket
+ self.params = params or {}
+ if self.params.get("config") is None:
+ self.params["config"] = {
+ "broadcast": {"ack": False, "self": False},
+ "presence": {"key": ""},
+ "private": False,
+ }
+
+ self.topic = topic
+ self._joined_once = False
+ self.bindings: Dict[str, List[Binding]] = {}
+ self.presence = AsyncRealtimePresence(self)
+ self.state = ChannelStates.CLOSED
+ self._push_buffer: List[AsyncPush] = []
+ self.timeout = self.socket.timeout
+
+ self.join_push = AsyncPush(self, ChannelEvents.join, self.params)
+ self.rejoin_timer = AsyncTimer(
+ self._rejoin_until_connected, lambda tries: 2**tries
+ )
+
+ self.broadcast_endpoint_url = self._broadcast_endpoint_url()
+
+ def on_join_push_ok(payload: Dict[str, Any], *args):
+ self.state = ChannelStates.JOINED
+ self.rejoin_timer.reset()
+ for push in self._push_buffer:
+ asyncio.create_task(push.send())
+ self._push_buffer = []
+
+ def on_join_push_timeout(*args):
+ if not self.is_joining:
+ return
+
+ logger.error(f"join push timeout for channel {self.topic}")
+ self.state = ChannelStates.ERRORED
+ self.rejoin_timer.schedule_timeout()
+
+ self.join_push.receive("ok", on_join_push_ok).receive(
+ "timeout", on_join_push_timeout
+ )
+
+ def on_close(*args):
+ logger.info(f"channel {self.topic} closed")
+ self.rejoin_timer.reset()
+ self.state = ChannelStates.CLOSED
+ self.socket._remove_channel(self)
+
+ def on_error(payload, *args):
+ if self.is_leaving or self.is_closed:
+ return
+
+ logger.info(f"channel {self.topic} error: {payload}")
+ self.state = ChannelStates.ERRORED
+ self.rejoin_timer.schedule_timeout()
+
+ self._on("close", on_close)
+ self._on("error", on_error)
+
+ def on_reply(payload, ref):
+ self._trigger(self._reply_event_name(ref), payload)
+
+ self._on(ChannelEvents.reply, on_reply)
+
+ # Properties
+ @property
+ def is_closed(self):
+ return self.state == ChannelStates.CLOSED
+
+ @property
+ def is_joining(self):
+ return self.state == ChannelStates.JOINING
+
+ @property
+ def is_leaving(self):
+ return self.state == ChannelStates.LEAVING
+
+ @property
+ def is_errored(self):
+ return self.state == ChannelStates.ERRORED
+
+ @property
+ def is_joined(self):
+ return self.state == ChannelStates.JOINED
+
+ @property
+ def join_ref(self):
+ return self.join_push.ref
+
+ # Core channel methods
+ async def subscribe(
+ self,
+ callback: Optional[
+ Callable[[RealtimeSubscribeStates, Optional[Exception]], None]
+ ] = None,
+ ) -> AsyncRealtimeChannel:
+ """
+ Subscribe to the channel. Can only be called once per channel instance.
+
+ :param callback: Optional callback function that receives subscription state updates
+ and any errors that occur during subscription
+ :return: The Channel instance for method chaining
+ :raises: Exception if called multiple times on the same channel instance
+ """
+ if not self.socket.is_connected:
+ await self.socket.connect()
+
+ if self._joined_once:
+ raise Exception(
+ "Tried to subscribe multiple times. 'subscribe' can only be called a single time per channel instance"
+ )
+ else:
+ config = self.params.get("config", {})
+ broadcast = config.get("broadcast", {})
+ presence = config.get("presence", {})
+ private = config.get("private", False)
+
+ access_token_payload = {}
+ config = {
+ "broadcast": broadcast,
+ "presence": presence,
+ "private": private,
+ "postgres_changes": list(
+ map(lambda x: x.filter, self.bindings.get("postgres_changes", []))
+ ),
+ }
+
+ if self.socket.access_token:
+ access_token_payload["access_token"] = self.socket.access_token
+
+ self.join_push.update_payload(
+ {**{"config": config}, **access_token_payload}
+ )
+ self._joined_once = True
+
+ def on_join_push_ok(payload: Dict[str, Any]):
+ server_postgres_changes: List[Dict[str, Any]] = payload.get(
+ "postgres_changes", []
+ )
+
+ if len(server_postgres_changes) == 0:
+ callback and callback(RealtimeSubscribeStates.SUBSCRIBED, None)
+ return
+
+ client_postgres_changes = self.bindings.get("postgres_changes", [])
+ new_postgres_bindings = []
+
+ bindings_len = len(client_postgres_changes)
+
+ for i in range(bindings_len):
+ client_binding = client_postgres_changes[i]
+ event = client_binding.filter.get("event")
+ schema = client_binding.filter.get("schema")
+ table = client_binding.filter.get("table")
+ filter = client_binding.filter.get("filter")
+
+ server_binding = (
+ server_postgres_changes[i]
+ if i < len(server_postgres_changes)
+ else None
+ )
+
+ if (
+ server_binding
+ and server_binding.get("event") == event
+ and server_binding.get("schema") == schema
+ and server_binding.get("table") == table
+ and server_binding.get("filter") == filter
+ ):
+ client_binding.id = server_binding.get("id")
+ new_postgres_bindings.append(client_binding)
+ else:
+ asyncio.create_task(self.unsubscribe())
+ callback and callback(
+ RealtimeSubscribeStates.CHANNEL_ERROR,
+ Exception(
+ "mismatch between server and client bindings for postgres changes"
+ ),
+ )
+ return
+
+ self.bindings["postgres_changes"] = new_postgres_bindings
+ callback and callback(RealtimeSubscribeStates.SUBSCRIBED, None)
+
+ def on_join_push_error(payload: Dict[str, Any]):
+ callback and callback(
+ RealtimeSubscribeStates.CHANNEL_ERROR,
+ Exception(json.dumps(payload)),
+ )
+
+ def on_join_push_timeout(*args):
+ callback and callback(RealtimeSubscribeStates.TIMED_OUT, None)
+
+ self.join_push.receive("ok", on_join_push_ok).receive(
+ "error", on_join_push_error
+ ).receive("timeout", on_join_push_timeout)
+
+ await self._rejoin()
+
+ return self
+
+ async def unsubscribe(self):
+ """
+ Unsubscribe from the channel and leave the topic.
+ Sets channel state to LEAVING and cleans up timers and pushes.
+ """
+ self.state = ChannelStates.LEAVING
+
+ self.rejoin_timer.reset()
+ self.join_push.destroy()
+
+ def _close(*args):
+ logger.info(f"channel {self.topic} leave")
+ self._trigger(ChannelEvents.close, "leave")
+
+ leave_push = AsyncPush(self, ChannelEvents.leave, {})
+ leave_push.receive("ok", _close).receive("timeout", _close)
+
+ await leave_push.send()
+
+ if not self._can_push():
+ leave_push.trigger("ok", {})
+
+ async def push(
+ self, event: str, payload: Dict[str, Any], timeout: Optional[int] = None
+ ) -> AsyncPush:
+ """
+ Push a message to the channel.
+
+ :param event: The event name to push
+ :param payload: The payload to send
+ :param timeout: Optional timeout in milliseconds
+ :return: AsyncPush instance representing the push operation
+ :raises: Exception if called before subscribing to the channel
+ """
+ if not self._joined_once:
+ raise Exception(
+ f"tried to push '{event}' to '{self.topic}' before joining. Use channel.subscribe() before pushing events"
+ )
+
+ timeout = timeout or self.timeout
+
+ push = AsyncPush(self, event, payload, timeout)
+ if self._can_push():
+ await push.send()
+ else:
+ push.start_timeout()
+ self._push_buffer.append(push)
+
+ return push
+
+ async def join(self) -> AsyncRealtimeChannel:
+ """
+ Coroutine that attempts to join Phoenix Realtime server via a certain topic.
+
+ :return: Channel
+ """
+ try:
+ await self.socket.send(
+ {
+ "topic": self.topic,
+ "event": "phx_join",
+ "payload": {"config": self.params},
+ "ref": None,
+ }
+ )
+ except Exception as e:
+ print(e)
+ return self
+
+ # Event handling methods
+ def _on(
+ self, type: str, callback: Callback, filter: Optional[Dict[str, Any]] = None
+ ) -> AsyncRealtimeChannel:
+ """
+ Set up a listener for a specific event.
+
+ :param type: The type of the event to listen for.
+ :param filter: Additional parameters for the event.
+ :param callback: The callback function to execute when the event is received.
+ :return: The Channel instance for method chaining.
+ """
+
+ type_lowercase = type.lower()
+ binding = Binding(type=type_lowercase, filter=filter or {}, callback=callback)
+ self.bindings.setdefault(type_lowercase, []).append(binding)
+
+ return self
+
+ def _off(self, type: str, filter: Dict[str, Any]) -> AsyncRealtimeChannel:
+ """
+ Remove a listener for a specific event type and filter.
+
+ :param type: The type of the event to remove the listener for.
+ :param filter: The filter associated with the listener to remove.
+ :return: The Channel instance for method chaining.
+
+ This method removes all bindings for the specified event type that match
+ the given filter. If no matching bindings are found, the method does nothing.
+ """
+ type_lowercase = type.lower()
+
+ if type_lowercase in self.bindings:
+ self.bindings[type_lowercase] = [
+ binding
+ for binding in self.bindings[type_lowercase]
+ if binding.filter != filter
+ ]
+ return self
+
+ def on_broadcast(
+ self, event: str, callback: Callable[[Dict[str, Any]], None]
+ ) -> AsyncRealtimeChannel:
+ """
+ Set up a listener for a specific broadcast event.
+
+ :param event: The name of the broadcast event to listen for
+ :param callback: Function called with the payload when a matching broadcast is received
+ :return: The Channel instance for method chaining
+ """
+ return self._on(
+ "broadcast",
+ filter={"event": event},
+ callback=lambda payload, _: callback(payload),
+ )
+
+ def on_postgres_changes(
+ self,
+ event: RealtimePostgresChangesListenEvent,
+ callback: Callable[[Dict[str, Any]], None],
+ table: str = "*",
+ schema: str = "public",
+ filter: Optional[str] = None,
+ ) -> AsyncRealtimeChannel:
+ """
+ Set up a listener for Postgres database changes.
+
+ :param event: The type of database event to listen for (INSERT, UPDATE, DELETE, or *)
+ :param callback: Function called with the payload when a matching change is detected
+ :param table: The table name to monitor. Defaults to "*" for all tables
+ :param schema: The database schema to monitor. Defaults to "public"
+ :param filter: Optional filter string to apply
+ :return: The Channel instance for method chaining
+ """
+
+ binding_filter = {"event": event, "schema": schema, "table": table}
+ if filter:
+ binding_filter["filter"] = filter
+
+ return self._on(
+ "postgres_changes",
+ filter=binding_filter,
+ callback=lambda payload, _: callback(payload),
+ )
+
+ def on_system(
+ self, callback: Callable[[Dict[str, Any], None]]
+ ) -> AsyncRealtimeChannel:
+ """
+ Set up a listener for system events.
+
+ :param callback: The callback function to execute when a system event is received.
+ :return: The Channel instance for method chaining.
+ """
+ return self._on("system", callback=lambda payload, _: callback(payload))
+
+ # Presence methods
+ async def track(self, user_status: Dict[str, Any]) -> None:
+ """
+ Track presence status for the current user.
+
+ :param user_status: Dictionary containing the user's presence information
+ """
+ await self.send_presence("track", user_status)
+
+ async def untrack(self) -> None:
+ """
+ Stop tracking presence for the current user.
+ """
+ await self.send_presence("untrack", {})
+
+ def presence_state(self) -> RealtimePresenceState:
+ """
+ Get the current state of presence on this channel.
+
+ :return: Dictionary mapping presence keys to lists of presence payloads
+ """
+ return self.presence.state
+
+ def on_presence_sync(self, callback: Callable[[], None]) -> AsyncRealtimeChannel:
+ """
+ Register a callback for presence sync events.
+
+ :param callback: The callback function to execute when a presence sync event occurs.
+ :return: The Channel instance for method chaining.
+ """
+ self.presence.on_sync(callback)
+ return self
+
+ def on_presence_join(
+ self, callback: PresenceOnJoinCallback
+ ) -> AsyncRealtimeChannel:
+ """
+ Register a callback for presence join events.
+
+ :param callback: The callback function to execute when a presence join event occurs.
+ :return: The Channel instance for method chaining.
+ """
+ self.presence.on_join(callback)
+ return self
+
+ def on_presence_leave(
+ self, callback: PresenceOnLeaveCallback
+ ) -> AsyncRealtimeChannel:
+ """
+ Register a callback for presence leave events.
+
+ :param callback: The callback function to execute when a presence leave event occurs.
+ :return: The Channel instance for method chaining.
+ """
+ self.presence.on_leave(callback)
+ return self
+
+ # Broadcast methods
+ async def send_broadcast(self, event: str, data: Any) -> None:
+ """
+ Send a broadcast message through this channel.
+
+ :param event: The name of the broadcast event
+ :param data: The payload to broadcast
+ """
+ await self.push(
+ ChannelEvents.broadcast,
+ {"type": "broadcast", "event": event, "payload": data},
+ )
+
+ # Internal methods
+ def _broadcast_endpoint_url(self):
+ return f"{http_endpoint_url(self.socket.http_endpoint)}/api/broadcast"
+
+ async def _rejoin(self) -> None:
+ if self.is_leaving:
+ return
+ await self.socket._leave_open_topic(self.topic)
+ self.state = ChannelStates.JOINING
+ await self.join_push.resend()
+
+ def _can_push(self):
+ return self.socket.is_connected and self._joined_once
+
+ async def send_presence(self, event: str, data: Any) -> None:
+ await self.push(ChannelEvents.presence, {"event": event, "payload": data})
+
+ def _trigger(self, type: str, payload: Optional[Any], ref: Optional[str] = None):
+ type_lowercase = type.lower()
+ events = [
+ ChannelEvents.close,
+ ChannelEvents.error,
+ ChannelEvents.leave,
+ ChannelEvents.join,
+ ]
+
+ if ref is not None and type_lowercase in events and ref != self.join_push.ref:
+ return
+
+ if type_lowercase in ["insert", "update", "delete"]:
+ postgres_changes = filter(
+ lambda binding: binding.filter.get("event", "").lower()
+ in [type_lowercase, "*"],
+ self.bindings.get("postgres_changes", []),
+ )
+ for binding in postgres_changes:
+ binding.callback(payload, ref)
+ else:
+ bindings = self.bindings.get(type_lowercase, [])
+ for binding in bindings:
+ if type_lowercase in ["broadcast", "postgres_changes", "presence"]:
+ bind_id = binding.id
+ bind_event = (
+ binding.filter.get("event", "").lower()
+ if binding.filter.get("event")
+ else None
+ )
+ payload_event = (
+ payload.get("event", "").lower()
+ if payload.get("event")
+ else None
+ )
+ data_type = (
+ payload.get("data", {}).get("type", "").lower()
+ if payload.get("data", {}).get("type")
+ else None
+ )
+ if (
+ bind_id
+ and bind_id in payload.get("ids", [])
+ and (bind_event == data_type or bind_event == "*")
+ ):
+ binding.callback(payload, ref)
+ elif bind_event in [payload_event, "*"]:
+ binding.callback(payload, ref)
+ elif binding.type == type_lowercase:
+ binding.callback(payload, ref)
+
+ def _reply_event_name(self, ref: str):
+ return f"chan_reply_{ref}"
+
+ async def _rejoin_until_connected(self):
+ self.rejoin_timer.schedule_timeout()
+ if self.socket.is_connected:
+ await self._rejoin()