Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions backend/chainlit/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,61 @@ async def remove(self):
raise e
logger.error(f"Failed to persist message deletion: {e!s}")

try:
await self.remove_children()
except Exception as e:
if self.fail_on_persist_error:
raise e
logger.error(f"Failed to persist message children deletion: {e!s}")

await context.emitter.delete_step(step_dict)

return True

async def remove_children(self):
data_layer = get_data_layer()
if not data_layer:
return

thread = await data_layer.get_thread(self.thread_id)
if thread is None:
return

steps = thread.get("steps", [])

def collect_descendants(parent_id: str, visited: Optional[set] = None) -> list:
"""Return descendant IDs in post-order (leaves first, parents last)."""
if visited is None:
visited = set()
if parent_id in visited:
return []
visited.add(parent_id)
result = []
for step in steps:
if step.get("parentId") == parent_id:
result.extend(collect_descendants(step["id"], visited))
result.append(step["id"])
return result

# Ordered leaves-first so that referential integrity constraints are respected.
ordered_descendant_ids = collect_descendants(self.id)
descendant_set = set(ordered_descendant_ids)

for step in steps:
step_id = step.get("id")
feedback_id = (step.get("feedback") or {}).get("id")
if step_id in descendant_set and feedback_id:
await data_layer.delete_feedback(feedback_id)

for element in thread.get("elements", []):
for_id = element.get("forId")
element_id = element.get("id")
if for_id in descendant_set and element_id:
await data_layer.delete_element(element_id, self.thread_id)

for step_id in ordered_descendant_ids:
await data_layer.delete_step(step_id)

async def _create(self):
step_dict = self.to_dict()
data_layer = get_data_layer()
Expand Down
1 change: 1 addition & 0 deletions backend/chainlit/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ async def edit_message(sid, payload: MessagePayload):
if message.id == payload["message"]["id"]:
message.content = payload["message"]["output"]
await message.update()
await message.remove_children()
orig_message = message

await context.emitter.task_start()
Expand Down
182 changes: 182 additions & 0 deletions backend/tests/test_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,3 +755,185 @@ def test_message_to_dict_with_none_metadata(self):
result = msg.to_dict()

assert result["metadata"] == {}


class TestRemoveChildren:
"""Test suite for Message.remove_children."""

def _make_message(self, msg_id="msg_1", thread_id="thread_1"):
with mock_chainlit_context():
msg = Message(content="test")
msg.id = msg_id
msg.thread_id = thread_id
return msg

def _tracked_data_layer(self, get_thread_result):
"""
Data layer mock whose delete_* / get_thread record only when awaited
(misses missing-await bugs). ``events`` is the strict call sequence.
"""
events: list[tuple] = []

async def get_thread(thread_id):
events.append(("get_thread", thread_id))
return get_thread_result

async def delete_feedback(feedback_id):
events.append(("delete_feedback", feedback_id))

async def delete_element(element_id, thread_id=None):
events.append(("delete_element", element_id, thread_id))

async def delete_step(step_id):
events.append(("delete_step", step_id))

mock_layer = AsyncMock()
mock_layer.get_thread = AsyncMock(side_effect=get_thread)
mock_layer.delete_feedback = AsyncMock(side_effect=delete_feedback)
mock_layer.delete_element = AsyncMock(side_effect=delete_element)
mock_layer.delete_step = AsyncMock(side_effect=delete_step)
return mock_layer, events

@pytest.mark.asyncio
async def test_no_data_layer(self):
"""Does nothing when there is no data layer."""
msg = self._make_message()
with patch("chainlit.message.get_data_layer", return_value=None):
await msg.remove_children()

@pytest.mark.asyncio
async def test_thread_not_found(self):
"""Does nothing when the thread does not exist."""
msg = self._make_message()
mock_data_layer, events = self._tracked_data_layer(None)

with patch("chainlit.message.get_data_layer", return_value=mock_data_layer):
await msg.remove_children()

assert events == [("get_thread", "thread_1")]

@pytest.mark.asyncio
async def test_no_children(self):
"""Does nothing when the message has no child steps."""
msg = self._make_message()
thread = {
"steps": [
{"id": "msg_1", "parentId": None},
{"id": "other_msg", "parentId": None},
],
"elements": [],
}
mock_data_layer, events = self._tracked_data_layer(thread)

with patch("chainlit.message.get_data_layer", return_value=mock_data_layer):
await msg.remove_children()

assert events == [("get_thread", "thread_1")]

@pytest.mark.asyncio
async def test_direct_children_deleted(self):
"""Deletes direct children of the message."""
msg = self._make_message()
thread = {
"steps": [
{"id": "msg_1", "parentId": None},
{"id": "child_1", "parentId": "msg_1"},
{"id": "child_2", "parentId": "msg_1"},
],
"elements": [],
}
mock_data_layer, events = self._tracked_data_layer(thread)

with patch("chainlit.message.get_data_layer", return_value=mock_data_layer):
await msg.remove_children()

assert events == [
("get_thread", "thread_1"),
("delete_step", "child_1"),
("delete_step", "child_2"),
]

@pytest.mark.asyncio
async def test_nested_descendants_deleted(self):
"""Recursively deletes grandchildren and deeper descendants."""
msg = self._make_message()
thread = {
"steps": [
{"id": "msg_1", "parentId": None},
{"id": "child_1", "parentId": "msg_1"},
{"id": "grandchild_1", "parentId": "child_1"},
{"id": "great_grandchild_1", "parentId": "grandchild_1"},
{"id": "unrelated", "parentId": None},
],
"elements": [],
}
mock_data_layer, events = self._tracked_data_layer(thread)

with patch("chainlit.message.get_data_layer", return_value=mock_data_layer):
await msg.remove_children()

assert events == [
("get_thread", "thread_1"),
("delete_step", "great_grandchild_1"),
("delete_step", "grandchild_1"),
("delete_step", "child_1"),
]

@pytest.mark.asyncio
async def test_feedback_and_elements_deleted_before_steps(self):
"""Removes feedback and elements for descendants before delete_step."""
msg = self._make_message()
thread = {
"steps": [
{"id": "msg_1", "parentId": None},
{
"id": "child_1",
"parentId": "msg_1",
"feedback": {
"id": "fb_1",
"forId": "child_1",
"value": 1,
"comment": None,
},
},
{"id": "child_2", "parentId": "msg_1"},
],
"elements": [
{"id": "el_1", "forId": "child_1", "threadId": "thread_1"},
{"id": "el_other", "forId": "msg_1", "threadId": "thread_1"},
{"id": "el_unrelated", "forId": "other_root", "threadId": "thread_1"},
],
}
mock_data_layer, events = self._tracked_data_layer(thread)

with patch("chainlit.message.get_data_layer", return_value=mock_data_layer):
await msg.remove_children()

assert events == [
("get_thread", "thread_1"),
("delete_feedback", "fb_1"),
("delete_element", "el_1", "thread_1"),
("delete_step", "child_1"),
("delete_step", "child_2"),
]

@pytest.mark.asyncio
async def test_message_itself_is_not_deleted(self):
"""The root message itself is never deleted, only its descendants."""
msg = self._make_message()
thread = {
"steps": [
{"id": "msg_1", "parentId": None},
{"id": "child_1", "parentId": "msg_1"},
],
"elements": [],
}
mock_data_layer, events = self._tracked_data_layer(thread)

with patch("chainlit.message.get_data_layer", return_value=mock_data_layer):
await msg.remove_children()

assert events == [
("get_thread", "thread_1"),
("delete_step", "child_1"),
]
Loading