diff --git a/backend/chainlit/message.py b/backend/chainlit/message.py index 0700f630a7..4c19ed900e 100644 --- a/backend/chainlit/message.py +++ b/backend/chainlit/message.py @@ -138,10 +138,61 @@ async def remove(self): raise e logger.error(f"Failed to persist message deletion: {e!s}") + try: + await self.remove_children() + except Exception as e: + if self.fail_on_persist_error: + raise e + logger.error(f"Failed to persist message children deletion: {e!s}") + await context.emitter.delete_step(step_dict) return True + async def remove_children(self): + data_layer = get_data_layer() + if not data_layer: + return + + thread = await data_layer.get_thread(self.thread_id) + if thread is None: + return + + steps = thread.get("steps", []) + + def collect_descendants(parent_id: str, visited: Optional[set] = None) -> list: + """Return descendant IDs in post-order (leaves first, parents last).""" + if visited is None: + visited = set() + if parent_id in visited: + return [] + visited.add(parent_id) + result = [] + for step in steps: + if step.get("parentId") == parent_id: + result.extend(collect_descendants(step["id"], visited)) + result.append(step["id"]) + return result + + # Ordered leaves-first so that referential integrity constraints are respected. + ordered_descendant_ids = collect_descendants(self.id) + descendant_set = set(ordered_descendant_ids) + + for step in steps: + step_id = step.get("id") + feedback_id = (step.get("feedback") or {}).get("id") + if step_id in descendant_set and feedback_id: + await data_layer.delete_feedback(feedback_id) + + for element in thread.get("elements", []): + for_id = element.get("forId") + element_id = element.get("id") + if for_id in descendant_set and element_id: + await data_layer.delete_element(element_id, self.thread_id) + + for step_id in ordered_descendant_ids: + await data_layer.delete_step(step_id) + async def _create(self): step_dict = self.to_dict() data_layer = get_data_layer() diff --git a/backend/chainlit/socket.py b/backend/chainlit/socket.py index 740f0c276a..6da4d6102a 100644 --- a/backend/chainlit/socket.py +++ b/backend/chainlit/socket.py @@ -336,6 +336,7 @@ async def edit_message(sid, payload: MessagePayload): if message.id == payload["message"]["id"]: message.content = payload["message"]["output"] await message.update() + await message.remove_children() orig_message = message await context.emitter.task_start() diff --git a/backend/tests/test_message.py b/backend/tests/test_message.py index 952f557414..241e58a4ba 100644 --- a/backend/tests/test_message.py +++ b/backend/tests/test_message.py @@ -755,3 +755,185 @@ def test_message_to_dict_with_none_metadata(self): result = msg.to_dict() assert result["metadata"] == {} + + +class TestRemoveChildren: + """Test suite for Message.remove_children.""" + + def _make_message(self, msg_id="msg_1", thread_id="thread_1"): + with mock_chainlit_context(): + msg = Message(content="test") + msg.id = msg_id + msg.thread_id = thread_id + return msg + + def _tracked_data_layer(self, get_thread_result): + """ + Data layer mock whose delete_* / get_thread record only when awaited + (misses missing-await bugs). ``events`` is the strict call sequence. + """ + events: list[tuple] = [] + + async def get_thread(thread_id): + events.append(("get_thread", thread_id)) + return get_thread_result + + async def delete_feedback(feedback_id): + events.append(("delete_feedback", feedback_id)) + + async def delete_element(element_id, thread_id=None): + events.append(("delete_element", element_id, thread_id)) + + async def delete_step(step_id): + events.append(("delete_step", step_id)) + + mock_layer = AsyncMock() + mock_layer.get_thread = AsyncMock(side_effect=get_thread) + mock_layer.delete_feedback = AsyncMock(side_effect=delete_feedback) + mock_layer.delete_element = AsyncMock(side_effect=delete_element) + mock_layer.delete_step = AsyncMock(side_effect=delete_step) + return mock_layer, events + + @pytest.mark.asyncio + async def test_no_data_layer(self): + """Does nothing when there is no data layer.""" + msg = self._make_message() + with patch("chainlit.message.get_data_layer", return_value=None): + await msg.remove_children() + + @pytest.mark.asyncio + async def test_thread_not_found(self): + """Does nothing when the thread does not exist.""" + msg = self._make_message() + mock_data_layer, events = self._tracked_data_layer(None) + + with patch("chainlit.message.get_data_layer", return_value=mock_data_layer): + await msg.remove_children() + + assert events == [("get_thread", "thread_1")] + + @pytest.mark.asyncio + async def test_no_children(self): + """Does nothing when the message has no child steps.""" + msg = self._make_message() + thread = { + "steps": [ + {"id": "msg_1", "parentId": None}, + {"id": "other_msg", "parentId": None}, + ], + "elements": [], + } + mock_data_layer, events = self._tracked_data_layer(thread) + + with patch("chainlit.message.get_data_layer", return_value=mock_data_layer): + await msg.remove_children() + + assert events == [("get_thread", "thread_1")] + + @pytest.mark.asyncio + async def test_direct_children_deleted(self): + """Deletes direct children of the message.""" + msg = self._make_message() + thread = { + "steps": [ + {"id": "msg_1", "parentId": None}, + {"id": "child_1", "parentId": "msg_1"}, + {"id": "child_2", "parentId": "msg_1"}, + ], + "elements": [], + } + mock_data_layer, events = self._tracked_data_layer(thread) + + with patch("chainlit.message.get_data_layer", return_value=mock_data_layer): + await msg.remove_children() + + assert events == [ + ("get_thread", "thread_1"), + ("delete_step", "child_1"), + ("delete_step", "child_2"), + ] + + @pytest.mark.asyncio + async def test_nested_descendants_deleted(self): + """Recursively deletes grandchildren and deeper descendants.""" + msg = self._make_message() + thread = { + "steps": [ + {"id": "msg_1", "parentId": None}, + {"id": "child_1", "parentId": "msg_1"}, + {"id": "grandchild_1", "parentId": "child_1"}, + {"id": "great_grandchild_1", "parentId": "grandchild_1"}, + {"id": "unrelated", "parentId": None}, + ], + "elements": [], + } + mock_data_layer, events = self._tracked_data_layer(thread) + + with patch("chainlit.message.get_data_layer", return_value=mock_data_layer): + await msg.remove_children() + + assert events == [ + ("get_thread", "thread_1"), + ("delete_step", "great_grandchild_1"), + ("delete_step", "grandchild_1"), + ("delete_step", "child_1"), + ] + + @pytest.mark.asyncio + async def test_feedback_and_elements_deleted_before_steps(self): + """Removes feedback and elements for descendants before delete_step.""" + msg = self._make_message() + thread = { + "steps": [ + {"id": "msg_1", "parentId": None}, + { + "id": "child_1", + "parentId": "msg_1", + "feedback": { + "id": "fb_1", + "forId": "child_1", + "value": 1, + "comment": None, + }, + }, + {"id": "child_2", "parentId": "msg_1"}, + ], + "elements": [ + {"id": "el_1", "forId": "child_1", "threadId": "thread_1"}, + {"id": "el_other", "forId": "msg_1", "threadId": "thread_1"}, + {"id": "el_unrelated", "forId": "other_root", "threadId": "thread_1"}, + ], + } + mock_data_layer, events = self._tracked_data_layer(thread) + + with patch("chainlit.message.get_data_layer", return_value=mock_data_layer): + await msg.remove_children() + + assert events == [ + ("get_thread", "thread_1"), + ("delete_feedback", "fb_1"), + ("delete_element", "el_1", "thread_1"), + ("delete_step", "child_1"), + ("delete_step", "child_2"), + ] + + @pytest.mark.asyncio + async def test_message_itself_is_not_deleted(self): + """The root message itself is never deleted, only its descendants.""" + msg = self._make_message() + thread = { + "steps": [ + {"id": "msg_1", "parentId": None}, + {"id": "child_1", "parentId": "msg_1"}, + ], + "elements": [], + } + mock_data_layer, events = self._tracked_data_layer(thread) + + with patch("chainlit.message.get_data_layer", return_value=mock_data_layer): + await msg.remove_children() + + assert events == [ + ("get_thread", "thread_1"), + ("delete_step", "child_1"), + ]