diff --git a/src/ypywidgets/comm.py b/src/ypywidgets/comm.py index 10f004c..b2d1d91 100644 --- a/src/ypywidgets/comm.py +++ b/src/ypywidgets/comm.py @@ -2,6 +2,7 @@ import comm from pycrdt import ( + Awareness, Doc, Text, TransactionEvent, @@ -10,6 +11,7 @@ create_sync_message, create_update_message, handle_sync_message, + read_message, ) from .widget import Widget @@ -48,10 +50,15 @@ def __init__( ) -> None: self._ydoc = ydoc self._comm = comm + self._awareness = Awareness(ydoc) msg = create_sync_message(ydoc) self._comm.send(buffers=[msg]) self._comm.on_msg(self._receive) + @property + def awareness(self) -> Awareness: + return self._awareness + def _receive(self, msg): message = bytes(msg["buffers"][0]) match message[0]: @@ -61,6 +68,10 @@ def _receive(self, msg): self._comm.send(buffers=[reply]) if message[1] == YSyncMessageType.SYNC_STEP2: self._ydoc.observe(self._send) + case YMessageType.AWARENESS: + # Same as pycrdt.websocket.yroom: strip Y message kind, decode body. + update = read_message(message[1:]) + self._awareness.apply_awareness_update(update, None) def _send(self, event: TransactionEvent): update = event.update @@ -86,7 +97,11 @@ def __init__( create_ydoc=not ydoc, ) self._comm = create_widget_comm(comm_data, comm_metadata, comm_id) - CommProvider(self.ydoc, self._comm) + self._comm_provider = CommProvider(self.ydoc, self._comm) + + @property + def awareness(self) -> Awareness: + return self._comm_provider.awareness def _repr_mimebundle_(self, *args, **kwargs): # pragma: nocover plaintext = repr(self) diff --git a/tests/test_comm_awareness.py b/tests/test_comm_awareness.py new file mode 100644 index 0000000..5f9ca0f --- /dev/null +++ b/tests/test_comm_awareness.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +import pytest +from pycrdt import Awareness, Doc, YMessageType, create_awareness_message +from ypywidgets.comm import CommWidget + +pytestmark = pytest.mark.anyio + + +async def test_comm_provider_applies_awareness_message(synced_widgets, context): + async with context: + local_widget = await synced_widgets.get_local_widget() + remote_awareness = Awareness(Doc()) + remote_awareness.set_local_state({"role": "remote"}) + payload = remote_awareness.encode_awareness_update([remote_awareness.client_id]) + message = create_awareness_message(payload) + + assert message[0] == YMessageType.AWARENESS + + local_widget._comm_provider._receive({"buffers": [message]}) + + remote_state = local_widget.awareness.states.get(remote_awareness.client_id) + assert remote_state is not None + assert remote_state.get("role") == "remote" + + +async def test_comm_widget_exposes_provider_awareness(): + widget = CommWidget() + assert widget.awareness is widget._comm_provider.awareness + + +async def test_comm_widget_awareness_observe_and_unobserve(): + widget = CommWidget() + + events: list[str] = [] + sub_id = widget.awareness.observe(lambda topic, _: events.append(topic)) + + widget.awareness.set_local_state({"ping": 1}) + assert events + + widget.awareness.unobserve(sub_id) + events.clear() + widget.awareness.set_local_state({"ping": 2}) + assert events == []