diff --git a/packages/slackBotFunction/app/slack/slack_events.py b/packages/slackBotFunction/app/slack/slack_events.py index 9ba1d0b7e..66ef059b8 100644 --- a/packages/slackBotFunction/app/slack/slack_events.py +++ b/packages/slackBotFunction/app/slack/slack_events.py @@ -3,6 +3,7 @@ Handles conversation memory, Bedrock queries, and responding back to Slack """ +from decimal import Decimal import re import time import traceback @@ -401,6 +402,7 @@ def process_async_slack_event(event: Dict[str, Any], event_id: str, client: WebC conversation_key, thread_ts = conversation_key_and_root(event) user_id = event.get("user", "unknown") channel_id = event["channel"] + conversation_key, thread_root = conversation_key_and_root(event=event) if message_text.lower().startswith(constants.PULL_REQUEST_PREFIX): try: @@ -496,6 +498,77 @@ def process_pull_request_slack_action(slack_body_data: Dict[str, Any]) -> None: # ================================================================ +def _handle_modified_messages( + client: WebClient, + channel: str, + thread_ts: str, + original_ts: str, + edited_event: Dict[str, Any], + event_id: str, + user_id: str, +) -> bool: + """Clear subsequent bot replies in a thread when the original message is edited.""" + try: + logger.info("Existing conversation found", extra={"event": edited_event}) + current_thread = client.conversations_replies(channel=channel, ts=thread_ts) + + thread_messages = current_thread.get("messages", []) + + if len(thread_messages) > 1: + # Get the original message timestamp + previous_edit_ts = Decimal(original_ts) + + # Filter messages that came after the edited message + subsequent_messages = [msg for msg in thread_messages if Decimal(msg.get("ts", "0")) > previous_edit_ts] + + # Check if there are any subsequent messages from the user + has_user_replies = any(msg.get("user") == user_id for msg in subsequent_messages) + + if has_user_replies: + # Not the last user message in the chain + try: + client.chat_postEphemeral( + channel=channel, + user=user_id, + thread_ts=thread_ts, + text="It looks like the conversation has diverged, please start a new conversation", + ) + except Exception as e: + logger.error( + f"Couldn't post ephemeral message: {e}", + extra={"event_id": event_id, "error": traceback.format_exc()}, + ) + return False + else: + logger.info( + "Found existing thread, clearing replies", + extra={"channel": channel, "thread_ts": thread_ts, "previous_edit_ts": previous_edit_ts}, + ) + deleted_count = 0 + for reply in subsequent_messages: + reply_ts = reply.get("ts") + if reply_ts: + try: + client.chat_delete(channel=channel, ts=reply_ts) + deleted_count += 1 + except Exception as e: + logger.error( + f"Couldn't delete message: {e}", + extra={"event_id": event_id, "error": traceback.format_exc()}, + ) + + logger.info(f"Deleted {deleted_count} replies") + return True + + return True + except Exception as e: + logger.error( + f"Error modifying existing messages: {e}", + extra={"event_id": event_id, "error": traceback.format_exc()}, + ) + return False + + def process_slack_message(event: Dict[str, Any], event_id: str, client: WebClient) -> None: """ Process Slack events asynchronously after initial acknowledgment @@ -507,6 +580,7 @@ def process_slack_message(event: Dict[str, Any], event_id: str, client: WebClien user_id = event["user"] channel = event["channel"] conversation_key, thread_ts = conversation_key_and_root(event) + edited_event = event.get("edited") # Remove Slack user mentions from message text user_query = re.sub(r"<@[UW][A-Z0-9]+(\|[^>]+)?>", "", event["text"]).strip() @@ -524,6 +598,20 @@ def process_slack_message(event: Dict[str, Any], event_id: str, client: WebClien client.chat_postMessage(**post_params) return + if edited_event: + has_modified = _handle_modified_messages( + client=client, + channel=channel, + thread_ts=thread_ts, + original_ts=event["ts"], + edited_event=edited_event, + event_id=event_id, + user_id=user_id, + ) + + if not has_modified: + return + # conversation continuity: reuse bedrock session across slack messages session_data = get_conversation_session_data(conversation_key) session_id = session_data.get("session_id") if session_data else None @@ -556,7 +644,8 @@ def process_slack_message(event: Dict[str, Any], event_id: str, client: WebClien store_qa_pair(conversation_key, user_query, response_text, message_ts, kb_response.get("sessionId"), user_id) try: - client.chat_update(channel=channel, ts=message_ts, text=response_text, blocks=blocks) + response = client.chat_update(channel=channel, ts=message_ts, text=response_text, blocks=blocks) + logger.info("Chat Updated", extra={"response": response}) except Exception as e: logger.error( f"Failed to update message: {e}", diff --git a/packages/slackBotFunction/tests/test_slack_events/test_slack_events_messages.py b/packages/slackBotFunction/tests/test_slack_events/test_slack_events_messages.py index 0a142ce54..0b3aa94b0 100644 --- a/packages/slackBotFunction/tests/test_slack_events/test_slack_events_messages.py +++ b/packages/slackBotFunction/tests/test_slack_events/test_slack_events_messages.py @@ -1,6 +1,6 @@ import sys import pytest -from unittest.mock import Mock, patch, MagicMock +from unittest.mock import ANY, Mock, patch, MagicMock @pytest.fixture @@ -866,3 +866,180 @@ def test_convert_markdown_to_slack_multiple_encoding_issues(mock_get_parameter: assert "ΓΆΒΆ" not in result assert "- Bullet point" in result assert "- another bullet" in result + + +# ================================================================ +# Tests for deleted messages after message edit +# ================================================================ + + +def test_handle_modified_messages_last_user_message_in_chain(): + """If the message is the last in the chain, delete reply and continue""" + if "app.slack.slack_events" in sys.modules: + del sys.modules["app.slack.slack_events"] + from app.slack.slack_events import _handle_modified_messages + + mock_client = Mock() + # Mock conversation history: Original user message + 2 Bot replies + mock_client.conversations_replies.return_value = { + "messages": [ + {"ts": "100.000", "user": "U123"}, + {"ts": "101.000", "user": "BOT_ID"}, + {"ts": "102.000", "user": "BOT_ID"}, + ] + } + + result = _handle_modified_messages( + client=mock_client, + channel="C123", + thread_ts="100.000", + original_ts="100.000", + edited_event={"ts": "105.000"}, + event_id="evt123", + user_id="U123", + ) + + # Assertions + assert result is True + assert mock_client.chat_delete.call_count == 2 + mock_client.chat_postEphemeral.assert_not_called() + + +def test_handle_modified_messages_not_last_user_message_in_chain(): + """If the message is not the last user message, post ephemeral msg and return False""" + if "app.slack.slack_events" in sys.modules: + del sys.modules["app.slack.slack_events"] + from app.slack.slack_events import _handle_modified_messages + + mock_client = Mock() + # Mock conversation history: Original message + Bot reply + Another user message (user replied again) + mock_client.conversations_replies.return_value = { + "messages": [ + {"ts": "100.000", "user": "U123"}, + {"ts": "101.000", "user": "BOT_ID"}, + {"ts": "102.000", "user": "U123"}, + ] + } + + result = _handle_modified_messages( + client=mock_client, + channel="C123", + thread_ts="100.000", + original_ts="100.000", + edited_event={"ts": "105.000"}, + event_id="evt123", + user_id="U123", + ) + + # Assertions + assert result is False + mock_client.chat_delete.assert_not_called() + mock_client.chat_postEphemeral.assert_called_once_with( + channel="C123", + user="U123", + thread_ts="100.000", + text="It looks like the conversation has diverged, please start a new conversation", + ) + + +def test_process_slack_message_halts_on_false_modified_handler( + mock_env: Mock, + mock_get_parameter: Mock, +): + """Test that process_slack_message stops processing if the edit is rejected (not last in chain)""" + mock_client = Mock() + + if "app.slack.slack_events" in sys.modules: + del sys.modules["app.slack.slack_events"] + from app.slack.slack_events import process_slack_message + + event = { + "text": "updated question", + "user": "U123", + "channel": "C123", + "ts": "123.123", + "edited": {"ts": "456.789"}, + "thread_ts": "123.123", + } + + # Patch dynamically inside the test body to avoid the module reload issue + with patch("app.slack.slack_events._handle_modified_messages") as mock_handle_modified_messages, patch( + "app.slack.slack_events.get_conversation_session_data" + ) as mock_get_conversation_session_data: + + mock_handle_modified_messages.return_value = False + + process_slack_message(event=event, event_id="evt123", client=mock_client) + + # Assertions + mock_handle_modified_messages.assert_called_once_with( + client=mock_client, + channel="C123", + thread_ts="123.123", + original_ts="123.123", + edited_event={"ts": "456.789"}, + event_id="evt123", + user_id="U123", + ) + + # Ensure it stopped execution and didn't try to fetch conversation memory / call Bedrock + mock_get_conversation_session_data.assert_not_called() + mock_client.chat_postMessage.assert_not_called() + + +def test_process_slack_message_continues_on_true_modified_handler(mock_env: Mock, mock_get_parameter: Mock): + """Test that process_slack_message completes if the edit is accepted""" + mock_client = Mock() + # Mock the response for the "Processing..." message + mock_client.chat_postMessage.return_value = {"ts": "999.999"} + + if "app.slack.slack_events" in sys.modules: + del sys.modules["app.slack.slack_events"] + from app.slack.slack_events import process_slack_message + + event = { + "text": "updated question", + "user": "U123", + "channel": "C123", + "ts": "123.123", + "edited": {"ts": "456.789"}, + "thread_ts": "123.123", + "event_ts": "123.123", # Required by log_query_stats + "channel_type": "channel", + } + + # Patch dynamically inside the test body to avoid the module reload issue wiping the mocks + with patch("app.slack.slack_events._handle_modified_messages") as mock_handle_modified_messages, patch( + "app.slack.slack_events.get_conversation_session_data" + ) as mock_get_conversation_session_data, patch( + "app.slack.slack_events.process_formatted_bedrock_query" + ) as mock_process_formatted_bedrock_query, patch( + "app.slack.slack_events._handle_session_management" + ), patch( + "app.slack.slack_events.store_qa_pair" + ), patch( + "app.slack.slack_events.log_query_stats" + ): + + mock_handle_modified_messages.return_value = True + + mock_get_conversation_session_data.return_value = {"session_id": "session-123"} + mock_process_formatted_bedrock_query.return_value = ({"sessionId": "session-123"}, "AI response", []) + + process_slack_message(event=event, event_id="evt123", client=mock_client) + + mock_handle_modified_messages.assert_called_once() + mock_get_conversation_session_data.assert_called_once() + + # Check it posted the "Processing..." message + mock_client.chat_postMessage.assert_called_once_with(channel="C123", text="Processing...", thread_ts="123.123") + + mock_process_formatted_bedrock_query.assert_called_once() + + # Check it updated the slack message + mock_client.chat_update.assert_called_once_with( + channel="C123", + ts="999.999", # This matches the mock_client.chat_postMessage.return_value + text="AI response", + blocks=ANY, + )