from __future__ import annotations
import asyncio
import json
import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
from realtime.types import (
Binding,
Callback,
ChannelEvents,
ChannelStates,
RealtimeChannelOptions,
RealtimePostgresChangesListenEvent,
RealtimePresenceState,
RealtimeSubscribeStates,
)
from ..transformers import http_endpoint_url
from .presence import (
AsyncRealtimePresence,
PresenceOnJoinCallback,
PresenceOnLeaveCallback,
)
from .push import AsyncPush
from .timer import AsyncTimer
# Per-module logger used for channel lifecycle and error messages.
logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    # Only needed for annotations (type checkers set TYPE_CHECKING); presumably
    # guarded to avoid a runtime import cycle with the client module.
    from .client import AsyncRealtimeClient


class AsyncRealtimeChannel:
    """
    Channel is an abstraction for a topic subscription on an existing socket connection.

    Each Channel has its own topic and a list of event-callbacks that respond to messages.
    Should only be instantiated through `AsyncRealtimeClient.channel(topic)`.
    """

    def __init__(
        self,
        socket: AsyncRealtimeClient,
        topic: str,
        params: Optional[RealtimeChannelOptions] = None,
    ) -> None:
        """
        Initialize the Channel object.

        :param socket: RealtimeClient object
        :param topic: Topic that it subscribes to on the realtime server
        :param params: Optional parameters for connection. If it has no ``config``
            entry, a default config (broadcast ack/self off, empty presence key,
            non-private) is written into this dict.
        """
        self.socket = socket
        self.params = params or {}
        if self.params.get("config") is None:
            self.params["config"] = {
                "broadcast": {"ack": False, "self": False},
                "presence": {"key": ""},
                "private": False,
            }

        self.topic = topic
        self._joined_once = False
        self.bindings: Dict[str, List[Binding]] = {}
        self.presence = AsyncRealtimePresence(self)
        self.state = ChannelStates.CLOSED
        # Pushes issued while the channel could not push; flushed once joined.
        self._push_buffer: List[AsyncPush] = []
        # Strong references to fire-and-forget tasks. The event loop only keeps
        # weak references to tasks, so an unreferenced task may be
        # garbage-collected before it completes.
        self._background_tasks: set = set()
        self.timeout = self.socket.timeout
        self.join_push = AsyncPush(self, ChannelEvents.join, self.params)
        # Delay between rejoin attempts grows as 2 ** tries (units are defined
        # by AsyncTimer).
        self.rejoin_timer = AsyncTimer(
            self._rejoin_until_connected, lambda tries: 2**tries
        )
        self.broadcast_endpoint_url = self._broadcast_endpoint_url()

        def on_join_push_ok(payload: Dict[str, Any], *args):
            # Joined: stop retrying and send everything queued while not joined.
            self.state = ChannelStates.JOINED
            self.rejoin_timer.reset()
            for push in self._push_buffer:
                self._spawn(push.send())
            self._push_buffer = []

        def on_join_push_timeout(*args):
            if not self.is_joining:
                return
            logger.error(f"join push timeout for channel {self.topic}")
            self.state = ChannelStates.ERRORED
            self.rejoin_timer.schedule_timeout()

        self.join_push.receive("ok", on_join_push_ok).receive(
            "timeout", on_join_push_timeout
        )

        def on_close(*args):
            # A close may be delivered more than once (e.g. by the leave push
            # handler in `unsubscribe` and by any close event routed to this
            # channel); only tear down on the first one.
            if self.is_closed:
                return
            logger.info(f"channel {self.topic} closed")
            self.rejoin_timer.reset()
            self.state = ChannelStates.CLOSED
            self.socket._remove_channel(self)

        def on_error(payload, *args):
            if self.is_leaving or self.is_closed:
                return
            logger.info(f"channel {self.topic} error: {payload}")
            self.state = ChannelStates.ERRORED
            self.rejoin_timer.schedule_timeout()

        # Register under the ChannelEvents names: `_on`/`_trigger` key bindings
        # by the lowercased event name, and close/error are triggered with
        # ChannelEvents members, so plain "close"/"error" keys never match.
        self._on(ChannelEvents.close, on_close)
        self._on(ChannelEvents.error, on_error)

        def on_reply(payload, ref):
            # Re-dispatch replies under a per-ref event name.
            self._trigger(self._reply_event_name(ref), payload)

        self._on(ChannelEvents.reply, on_reply)

    # Properties

    @property
    def is_closed(self):
        return self.state == ChannelStates.CLOSED

    @property
    def is_joining(self):
        return self.state == ChannelStates.JOINING

    @property
    def is_leaving(self):
        return self.state == ChannelStates.LEAVING

    @property
    def is_errored(self):
        return self.state == ChannelStates.ERRORED

    @property
    def is_joined(self):
        return self.state == ChannelStates.JOINED

    @property
    def join_ref(self):
        return self.join_push.ref

    # Core channel methods

    async def subscribe(
        self,
        callback: Optional[
            Callable[[RealtimeSubscribeStates, Optional[Exception]], None]
        ] = None,
    ) -> AsyncRealtimeChannel:
        """
        Subscribe to the channel. Can only be called once per channel instance.

        Connects the socket if needed, then sends the join push carrying the
        channel config, the registered postgres_changes filters and (when set)
        the socket's access token.

        :param callback: Optional callback function that receives subscription state updates
            and any errors that occur during subscription
        :return: The Channel instance for method chaining
        :raises: Exception if called multiple times on the same channel instance
        """
        if not self.socket.is_connected:
            await self.socket.connect()

        # There is no await between this check and setting `_joined_once`
        # below, so concurrent callers cannot both get past it.
        if self._joined_once:
            raise Exception(
                "Tried to subscribe multiple times. 'subscribe' can only be called a single time per channel instance"
            )

        config = self.params.get("config", {})
        join_config = {
            "broadcast": config.get("broadcast", {}),
            "presence": config.get("presence", {}),
            "private": config.get("private", False),
            "postgres_changes": [
                binding.filter
                for binding in self.bindings.get("postgres_changes", [])
            ],
        }
        join_payload: Dict[str, Any] = {"config": join_config}
        if self.socket.access_token:
            join_payload["access_token"] = self.socket.access_token
        self.join_push.update_payload(join_payload)
        self._joined_once = True

        def on_join_push_ok(payload: Dict[str, Any]):
            # The server echoes postgres_changes filters with ids assigned;
            # tolerate a missing or null entry.
            server_postgres_changes: List[Dict[str, Any]] = (
                payload.get("postgres_changes") or []
            )
            if not server_postgres_changes:
                if callback:
                    callback(RealtimeSubscribeStates.SUBSCRIBED, None)
                return

            client_postgres_changes = self.bindings.get("postgres_changes", [])
            new_postgres_bindings = []
            for i, client_binding in enumerate(client_postgres_changes):
                server_binding = (
                    server_postgres_changes[i]
                    if i < len(server_postgres_changes)
                    else None
                )
                client_filter = client_binding.filter
                matches = bool(server_binding) and all(
                    server_binding.get(key) == client_filter.get(key)
                    for key in ("event", "schema", "table", "filter")
                )
                if not matches:
                    # Client and server disagree on the subscription set: leave
                    # the channel and report the error.
                    self._spawn(self.unsubscribe())
                    if callback:
                        callback(
                            RealtimeSubscribeStates.CHANNEL_ERROR,
                            Exception(
                                "mismatch between server and client bindings for postgres changes"
                            ),
                        )
                    return
                # Remember the server-assigned id so `_trigger` can route by it.
                client_binding.id = server_binding.get("id")
                new_postgres_bindings.append(client_binding)

            self.bindings["postgres_changes"] = new_postgres_bindings
            if callback:
                callback(RealtimeSubscribeStates.SUBSCRIBED, None)

        def on_join_push_error(payload: Dict[str, Any]):
            if callback:
                callback(
                    RealtimeSubscribeStates.CHANNEL_ERROR,
                    Exception(json.dumps(payload)),
                )

        def on_join_push_timeout(*args):
            if callback:
                callback(RealtimeSubscribeStates.TIMED_OUT, None)

        self.join_push.receive("ok", on_join_push_ok).receive(
            "error", on_join_push_error
        ).receive("timeout", on_join_push_timeout)

        await self._rejoin()

        return self

    async def unsubscribe(self):
        """
        Unsubscribe from the channel and leave the topic.

        Sets channel state to LEAVING, resets the rejoin timer, destroys the join
        push and sends a leave push. The channel-close event is triggered when
        the leave push is acknowledged or times out. If the channel cannot push,
        that acknowledgement is simulated locally.
        """
        self.state = ChannelStates.LEAVING

        self.rejoin_timer.reset()
        self.join_push.destroy()

        def _close(*args):
            logger.info(f"channel {self.topic} leave")
            self._trigger(ChannelEvents.close, "leave")

        leave_push = AsyncPush(self, ChannelEvents.leave, {})
        leave_push.receive("ok", _close).receive("timeout", _close)

        await leave_push.send()

        if not self._can_push():
            leave_push.trigger("ok", {})

    async def push(
        self, event: str, payload: Dict[str, Any], timeout: Optional[int] = None
    ) -> AsyncPush:
        """
        Push a message to the channel.

        If the channel cannot push yet, the push's timeout is started and the push
        is buffered until the channel joins.

        :param event: The event name to push
        :param payload: The payload to send
        :param timeout: Optional timeout in milliseconds; a falsy value (None or 0)
            falls back to the channel's timeout
        :return: AsyncPush instance representing the push operation
        :raises: Exception if called before subscribing to the channel
        """
        if not self._joined_once:
            raise Exception(
                f"tried to push '{event}' to '{self.topic}' before joining. Use channel.subscribe() before pushing events"
            )

        timeout = timeout or self.timeout

        push = AsyncPush(self, event, payload, timeout)
        if self._can_push():
            await push.send()
        else:
            push.start_timeout()
            self._push_buffer.append(push)

        return push

    async def join(self) -> AsyncRealtimeChannel:
        """
        Coroutine that attempts to join Phoenix Realtime server via a certain topic.

        Sends a raw ``phx_join`` frame over the socket. This is best effort: a send
        failure is logged (with traceback) and the channel is returned anyway.

        :return: Channel
        """
        try:
            await self.socket.send(
                {
                    "topic": self.topic,
                    "event": "phx_join",
                    # NOTE(review): self.params always contains a "config" key
                    # (see __init__), so this nests it as {"config": {"config": ...}};
                    # `subscribe` sends the flat shape instead. Confirm which one
                    # the server expects.
                    "payload": {"config": self.params},
                    "ref": None,
                }
            )
        except Exception:
            logger.exception(f"failed to join channel {self.topic}")

        return self

    # Event handling methods

    def _on(
        self, type: str, callback: Callback, filter: Optional[Dict[str, Any]] = None
    ) -> AsyncRealtimeChannel:
        """
        Set up a listener for a specific event.

        :param type: The type of the event to listen for (stored lowercased).
        :param callback: The callback function to execute when the event is received.
        :param filter: Additional parameters for the event.
        :return: The Channel instance for method chaining.
        """
        type_lowercase = type.lower()
        binding = Binding(type=type_lowercase, filter=filter or {}, callback=callback)
        self.bindings.setdefault(type_lowercase, []).append(binding)

        return self

    def _off(self, type: str, filter: Dict[str, Any]) -> AsyncRealtimeChannel:
        """
        Remove a listener for a specific event type and filter.

        This method removes all bindings for the specified event type that match
        the given filter. If no matching bindings are found, the method does nothing.

        :param type: The type of the event to remove the listener for (case-insensitive).
        :param filter: The filter associated with the listener to remove.
        :return: The Channel instance for method chaining.
        """
        type_lowercase = type.lower()

        if type_lowercase in self.bindings:
            self.bindings[type_lowercase] = [
                binding
                for binding in self.bindings[type_lowercase]
                if binding.filter != filter
            ]
        return self

    def on_broadcast(
        self, event: str, callback: Callable[[Dict[str, Any]], None]
    ) -> AsyncRealtimeChannel:
        """
        Set up a listener for a specific broadcast event.

        :param event: The name of the broadcast event to listen for
        :param callback: Function called with the payload when a matching broadcast is received
        :return: The Channel instance for method chaining
        """
        return self._on(
            "broadcast",
            filter={"event": event},
            callback=lambda payload, _: callback(payload),
        )

    def on_postgres_changes(
        self,
        event: RealtimePostgresChangesListenEvent,
        callback: Callable[[Dict[str, Any]], None],
        table: str = "*",
        schema: str = "public",
        filter: Optional[str] = None,
    ) -> AsyncRealtimeChannel:
        """
        Set up a listener for Postgres database changes.

        Must be registered before `subscribe`, which sends these filters in the join.

        :param event: The type of database event to listen for (INSERT, UPDATE, DELETE, or *)
        :param callback: Function called with the payload when a matching change is detected
        :param table: The table name to monitor. Defaults to "*" for all tables
        :param schema: The database schema to monitor. Defaults to "public"
        :param filter: Optional filter string to apply
        :return: The Channel instance for method chaining
        """
        binding_filter = {"event": event, "schema": schema, "table": table}
        if filter:
            binding_filter["filter"] = filter

        return self._on(
            "postgres_changes",
            filter=binding_filter,
            callback=lambda payload, _: callback(payload),
        )

    def on_system(
        self, callback: Callable[[Dict[str, Any]], None]
    ) -> AsyncRealtimeChannel:
        """
        Set up a listener for system events.

        :param callback: The callback function to execute when a system event is received.
        :return: The Channel instance for method chaining.
        """
        return self._on("system", callback=lambda payload, _: callback(payload))

    # Presence methods

    async def track(self, user_status: Dict[str, Any]) -> None:
        """
        Track presence status for the current user.

        :param user_status: Dictionary containing the user's presence information
        """
        await self.send_presence("track", user_status)

    async def untrack(self) -> None:
        """
        Stop tracking presence for the current user.
        """
        await self.send_presence("untrack", {})

    def presence_state(self) -> RealtimePresenceState:
        """
        Get the current state of presence on this channel.

        :return: Dictionary mapping presence keys to lists of presence payloads
        """
        return self.presence.state

    def on_presence_sync(self, callback: Callable[[], None]) -> AsyncRealtimeChannel:
        """
        Register a callback for presence sync events.

        :param callback: The callback function to execute when a presence sync event occurs.
        :return: The Channel instance for method chaining.
        """
        self.presence.on_sync(callback)
        return self

    def on_presence_join(
        self, callback: PresenceOnJoinCallback
    ) -> AsyncRealtimeChannel:
        """
        Register a callback for presence join events.

        :param callback: The callback function to execute when a presence join event occurs.
        :return: The Channel instance for method chaining.
        """
        self.presence.on_join(callback)
        return self

    def on_presence_leave(
        self, callback: PresenceOnLeaveCallback
    ) -> AsyncRealtimeChannel:
        """
        Register a callback for presence leave events.

        :param callback: The callback function to execute when a presence leave event occurs.
        :return: The Channel instance for method chaining.
        """
        self.presence.on_leave(callback)
        return self

    # Broadcast methods

    async def send_broadcast(self, event: str, data: Any) -> None:
        """
        Send a broadcast message through this channel.

        :param event: The name of the broadcast event
        :param data: The payload to broadcast
        """
        await self.push(
            ChannelEvents.broadcast,
            {"type": "broadcast", "event": event, "payload": data},
        )

    # Internal methods

    def _broadcast_endpoint_url(self):
        return f"{http_endpoint_url(self.socket.http_endpoint)}/api/broadcast"

    def _spawn(self, coro) -> asyncio.Task:
        """
        Schedule ``coro`` as a task on the running loop, keeping a strong
        reference to it until it finishes so it cannot be garbage-collected early.
        """
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    async def _rejoin(self) -> None:
        """Re-send the join push unless the channel is leaving."""
        if self.is_leaving:
            return
        # Presumably clears any other open subscription to the same topic on
        # the socket before joining; see AsyncRealtimeClient._leave_open_topic.
        await self.socket._leave_open_topic(self.topic)
        self.state = ChannelStates.JOINING
        await self.join_push.resend()

    def _can_push(self):
        """True when the socket is connected and `subscribe` has been called."""
        return self.socket.is_connected and self._joined_once

    async def send_presence(self, event: str, data: Any) -> None:
        await self.push(ChannelEvents.presence, {"event": event, "payload": data})

    @staticmethod
    def _lowered(value: Any) -> Optional[str]:
        """Return ``value.lower()`` for a truthy value, otherwise ``None``."""
        return value.lower() if value else None

    def _trigger(self, type: str, payload: Optional[Any], ref: Optional[str] = None):
        """
        Dispatch an event to the bindings registered for it.

        :param type: Event name, matched case-insensitively against binding types.
        :param payload: Event payload. For ``broadcast``/``postgres_changes``/``presence``
            events it is read as a dict (``event``, ``data.type`` and ``ids`` keys).
        :param ref: Message ref, forwarded to each callback.
        """
        type_lowercase = type.lower()

        lifecycle_events = (
            ChannelEvents.close,
            ChannelEvents.error,
            ChannelEvents.leave,
            ChannelEvents.join,
        )
        # Lifecycle events whose ref differs from the current join ref are
        # ignored (presumably stale messages from an earlier join).
        if (
            ref is not None
            and type_lowercase in lifecycle_events
            and ref != self.join_push.ref
        ):
            return

        if type_lowercase in ("insert", "update", "delete"):
            # Row-change events go to postgres_changes bindings listening for
            # this event type or for every event ("*").
            for binding in self.bindings.get("postgres_changes", []):
                if binding.filter.get("event", "").lower() in (type_lowercase, "*"):
                    binding.callback(payload, ref)
            return

        filtered_types = ("broadcast", "postgres_changes", "presence")
        for binding in self.bindings.get(type_lowercase, []):
            if type_lowercase not in filtered_types:
                if binding.type == type_lowercase:
                    binding.callback(payload, ref)
                continue

            bind_event = self._lowered(binding.filter.get("event"))
            payload_event = self._lowered(payload.get("event"))
            data_type = self._lowered((payload.get("data") or {}).get("type"))

            if (
                binding.id
                and binding.id in (payload.get("ids") or [])
                and bind_event in (data_type, "*")
            ):
                # The payload lists this binding's server-assigned id.
                binding.callback(payload, ref)
            elif bind_event in (payload_event, "*"):
                binding.callback(payload, ref)

    def _reply_event_name(self, ref: str):
        return f"chan_reply_{ref}"

    async def _rejoin_until_connected(self):
        """Rejoin timer callback: re-arm the timer, and rejoin if connected."""
        self.rejoin_timer.schedule_timeout()
        if self.socket.is_connected:
            await self._rejoin()