diff --git a/django_app/tests/test_consumers.py b/django_app/tests/test_consumers.py index 55cc84245..35c0d4df3 100644 --- a/django_app/tests/test_consumers.py +++ b/django_app/tests/test_consumers.py @@ -67,7 +67,7 @@ async def test_chat_consumer_with_new_session( communicator = WebsocketCommunicator(ChatConsumer.as_asgi(), "/ws/chat/") communicator.scope["user"] = alice - connected, _ = await communicator.connect() + connected, _ = await communicator.connect(timeout=5) assert connected with patch("redbox_app.redbox_core.consumers.ChatConsumer.redbox.graph", new=mocked_connect): await communicator.send_json_to({"message": "Hello Hal."}) @@ -908,7 +908,7 @@ async def test_connect_with_agents_update_via_db(agents_list: list, alice: User) mock_get.return_value = agents_list communicator = WebsocketCommunicator(ChatConsumer.as_asgi(), "/ws/chat/") communicator.scope["user"] = alice - await communicator.connect() + await communicator.connect(timeout=5) assert mock_get.call_count == 1 assert "Fake_Agent" not in list(ChatConsumer.redbox.agent_configs.keys()) diff --git a/redbox/redbox/api/format.py b/redbox/redbox/api/format.py index 76364d26c..554ddbb98 100644 --- a/redbox/redbox/api/format.py +++ b/redbox/redbox/api/format.py @@ -69,6 +69,9 @@ def format_mcp_tool_response(tool_response, creator_type: ChunkCreatorType) -> t deep_links = [] match result_type: case "nullable": + if isinstance(result, str): + return result, metadata + deep_links = [(result.get("url"), result)] case "paged": deep_links = [(p.get("url"), p) for p in result.get("items", [])] @@ -99,4 +102,7 @@ def format_mcp_tool_response(tool_response, creator_type: ChunkCreatorType) -> t else: response.append(json.dumps(item)) + if not response: + return ("No results found.", metadata) + return ("\n\n".join(response), metadata) diff --git a/redbox/redbox/graph/nodes/processes.py b/redbox/redbox/graph/nodes/processes.py index bf7fa38cf..a0fa46c1d 100644 --- a/redbox/redbox/graph/nodes/processes.py +++ 
b/redbox/redbox/graph/nodes/processes.py @@ -699,22 +699,41 @@ def local_loop_condition(): result = "Tool error: no results received." elif has_loop and len(ai_msg.tool_calls) > 0: # if loop, we need to transform results - result = result[-1].content # this is a tuple - # format of result: (result, success, is_intermediate_step) - log.warning("my-overall-result") - log.warning(result) - result_content = result[0] - success = result[1] - is_intermediate_step = eval(result[2]) + collated_result = "" + feedback_reasons = [] + for i, r in enumerate(result): + current_result = r.content # this is a tuple + # format of result: (result, success, is_intermediate_step) + log.debug("my-overall-result") + log.debug(current_result) + result_content = current_result[0] + success = current_result[1] + is_intermediate_step = eval(current_result[2]) + + if len(current_result) > 3: + reason = current_result[3] + feedback_reasons.append(f"Failure reason: {reason}.\n\n{result_content}") + else: + collated_result += f"{result_content}" + + if success == "fail": + # pass error back if any + additional_variables.update({"previous_tool_error": result_content}) + else: + # if success tool invocation, and intermediate steps then pass info back + if is_intermediate_step: + additional_variables.update( + {"previous_tool_error": "", "previous_tool_results": all_results} + ) - if len(result) > 3: - reason = result[3] + if feedback_reasons: + combined_feedback = "\n\n".join(feedback_reasons) return { "agents_results": { task.id: AIMessage( - content=f"<{agent_name}_Result>Ask user for feedback based on failure reason. Failure reason: {reason}.\n\n{result_content}", + content=f"<{agent_name}_Result>Ask user for feedback based on failure reason. 
{combined_feedback}\n\n{collated_result}", kwargs={ - "reason": reason, + "reason": combined_feedback, }, ) }, @@ -722,14 +741,7 @@ def local_loop_condition(): "agent_plans": state.agent_plans.update_task_status(task.id, TaskStatus.REQUIRES_USER_FEEDBACK), } - if success == "fail": - # pass error back if any - additional_variables.update({"previous_tool_error": result_content}) - else: - # if success tool invocation, and intermediate steps then pass info back - if is_intermediate_step: - additional_variables.update({"previous_tool_error": "", "previous_tool_results": all_results}) - result = result_content + result = collated_result if isinstance(result, str): log.warning(f"{log_stub} Using raw string result.") diff --git a/redbox/redbox/graph/nodes/sends.py b/redbox/redbox/graph/nodes/sends.py index de3412c27..71adadf11 100644 --- a/redbox/redbox/graph/nodes/sends.py +++ b/redbox/redbox/graph/nodes/sends.py @@ -127,24 +127,23 @@ def wrap_async_tool(tool, tool_name): """ def wrapper(args): - # Create a new event loop for this thread - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - # get mcp tool url mcp_url = tool.metadata["url"] creator_type = tool.metadata["creator_type"] - sso_access_token = tool.metadata["sso_access_token"].get() + + try: + sso_access_token = tool.metadata["sso_access_token"].get() + except Exception as e: + log.error(f"wrap_async_tool - Failed to retrieve sso_access_token: {e}") + raise if not sso_access_token: log.error("wrap_async_tool - MCP sso_access_token is None") headers = _get_mcp_headers(sso_access_token) - try: - # Define the async operation - async def run_tool(): - # tool need to be executed within the connection context manager + async def run_tool(): + try: async with streamablehttp_client(mcp_url, headers=headers or None) as ( read, write, @@ -168,7 +167,7 @@ async def run_tool(): raise ValueError(f"tool with name '{tool_name}' not found") # remove intermediate step argument if it is not required by tool - if 
"is_intermediate_step" not in selected_tool.args_schema["required"] and args.get( + if "is_intermediate_step" not in selected_tool.args_schema.get("required", []) and args.get( "is_intermediate_step" ): args.pop("is_intermediate_step") @@ -193,12 +192,15 @@ async def run_tool(): f"wrap_async_tool - Returning raw MCP tool response for creator_type='{creator_type}'" ) return result + except Exception as e: + log.error(f"wrap_async_tool - Failed to connect to MCP server at '{mcp_url}': {e}") + raise - # Run the async function and return its result - return loop.run_until_complete(run_tool()) - finally: - # Clean up resources - loop.close() + try: + return asyncio.run(run_tool()) + except Exception as e: + log.error(f"wrap_async_tool - Unhandled error running tool '{tool_name}': {e}", exc_info=True) + raise return wrapper diff --git a/redbox/tests/graph/nodes/test_sends.py b/redbox/tests/graph/nodes/test_sends.py index c5e310bc9..e70faaaa5 100644 --- a/redbox/tests/graph/nodes/test_sends.py +++ b/redbox/tests/graph/nodes/test_sends.py @@ -561,6 +561,47 @@ def test_returns_expected_results( assert result == expected_documents assert metadata == expected_tool_metadata + @pytest.mark.parametrize("expected_tool_result, expected_documents", MCP_TOOL_RESULTS) + @patch("redbox.graph.nodes.sends.ClientSession") + @patch("redbox.graph.nodes.sends.streamablehttp_client") + @patch("redbox.graph.nodes.sends.load_mcp_tools") + def test_returns_expected_results_no_args( + self, + mock_load_tools, + mock_http_client, + mock_session_class, + fake_mcp_tool, + expected_tool_result, + expected_documents, + ): + """Test that wrap_async_tool correctly returns results from async tool invocation""" + expected_tool_content, expected_tool_metadata = expected_tool_result + + # Mock tool with metadata + tool_name = "company_tool" + args_schema = {} + tool = fake_mcp_tool(tool_name, return_value=expected_tool_content, args_schema=args_schema) + + # mock session with patched mcp setup + 
mock_session = self._patch_mcp_env(mock_load_tools, mock_http_client, mock_session_class, [tool]) + + # create the wrapped function + wrapped_func = wrap_async_tool(tool, tool_name) + + # test invocation with sample args + test_args = {} + result, metadata = wrapped_func(test_args) + + # verify correct interactions + mock_http_client.assert_called_once_with(tool.metadata["url"], headers=None) + mock_session.initialize.assert_called_once() + mock_load_tools.assert_called_once_with(mock_session) + tool.ainvoke.assert_called_once_with(test_args) + + # assert the result matches our expected output + assert result == expected_documents + assert metadata == expected_tool_metadata + @patch("redbox.graph.nodes.sends.ClientSession") + @patch("redbox.graph.nodes.sends.streamablehttp_client") + @patch("redbox.graph.nodes.sends.load_mcp_tools", new_callable=AsyncMock) @@ -576,6 +617,87 @@ def test_tool_not_found(self, mock_load_tools, mock_http_client, mock_session_cl with pytest.raises(ValueError, match="tool with name 'missing_tool' not found"): wrapped_func({"foo": "bar"}) + def test_sso_token_retrieval_failure(self, fake_mcp_tool): + """Test that wrap_async_tool raises when sso_access_token.get() fails.""" + tool = fake_mcp_tool("dummy_tool", return_value=None) + tool.metadata["sso_access_token"] = MagicMock() + tool.metadata["sso_access_token"].get.side_effect = RuntimeError("vault unavailable") + + wrapped_func = wrap_async_tool(tool, "dummy_tool") + + with pytest.raises(RuntimeError, match="vault unavailable"): + wrapped_func({"foo": "bar"}) + + @patch("redbox.graph.nodes.sends.ClientSession") + @patch("redbox.graph.nodes.sends.streamablehttp_client") + @patch("redbox.graph.nodes.sends.load_mcp_tools") + def test_intermediate_step_stripped_when_not_in_schema( + self, + mock_load_tools, + mock_http_client, + mock_session_class, + fake_mcp_tool, + ): + """Test that is_intermediate_step is removed from args when not required by the tool schema.""" + return_value = "some 
content" + tool_name = "company_tool" + args_schema = {"company_name": {"type": "string"}, "required": ["company_name"]} + tool = fake_mcp_tool(tool_name, return_value=return_value, args_schema=args_schema) + + # make ainvoke return something format_mcp_tool_response can handle, or use a non-datahub type + tool.metadata["creator_type"] = MagicMock() # non-datahub, returns raw result + tool.ainvoke = AsyncMock(return_value=return_value) + + self._patch_mcp_env(mock_load_tools, mock_http_client, mock_session_class, [tool]) + + wrapped_func = wrap_async_tool(tool, tool_name) + + test_args = {"company_name": "BMW", "is_intermediate_step": True} + wrapped_func(test_args) + + # is_intermediate_step should have been popped before ainvoke was called + tool.ainvoke.assert_called_once_with({"company_name": "BMW"}) + + @patch("redbox.graph.nodes.sends.ClientSession") + @patch("redbox.graph.nodes.sends.streamablehttp_client") + @patch("redbox.graph.nodes.sends.load_mcp_tools") + def test_intermediate_step_retained_when_in_schema( + self, + mock_load_tools, + mock_http_client, + mock_session_class, + fake_mcp_tool, + ): + """Test that is_intermediate_step is kept in args when the tool schema requires it.""" + return_value = "some content" + tool_name = "company_tool" + args_schema = {"company_name": {"type": "string"}, "required": ["company_name", "is_intermediate_step"]} + tool = fake_mcp_tool(tool_name, return_value=return_value, args_schema=args_schema) + + tool.metadata["creator_type"] = MagicMock() + tool.ainvoke = AsyncMock(return_value=return_value) + + self._patch_mcp_env(mock_load_tools, mock_http_client, mock_session_class, [tool]) + + wrapped_func = wrap_async_tool(tool, tool_name) + + test_args = {"company_name": "BMW", "is_intermediate_step": True} + wrapped_func(test_args) + + # is_intermediate_step should be preserved + tool.ainvoke.assert_called_once_with({"company_name": "BMW", "is_intermediate_step": True}) + + @patch("redbox.graph.nodes.sends.asyncio") + def 
test_asyncio_run_failure(self, mock_asyncio, fake_mcp_tool): + """Test that wrap_async_tool re-raises when asyncio.run itself fails.""" + tool = fake_mcp_tool("dummy_tool", return_value=None) + mock_asyncio.run.side_effect = RuntimeError("event loop closed") + + wrapped_func = wrap_async_tool(tool, "dummy_tool") + + with pytest.raises(RuntimeError, match="event loop closed"): + wrapped_func({"foo": "bar"}) + @pytest.mark.parametrize( "token_input, expected_output", diff --git a/redbox/tests/graph/test_patterns.py b/redbox/tests/graph/test_patterns.py index a5ed48c46..dac6df798 100644 --- a/redbox/tests/graph/test_patterns.py +++ b/redbox/tests/graph/test_patterns.py @@ -1017,6 +1017,158 @@ async def test_llm_response_truncation( assert len(response) == 3 # content, eval task, task status assert mock_tool_calls.call_count == 1 + @pytest.mark.parametrize( + "test_name, tool_results, expect_feedback, expected_reasons", + [ + ( + "feedback-required-single-tool", + [ + AIMessage( + content=( + "some results", + "pass", + "False", + "Multiple company records returned user must clarify which one they want.", + ) + ) + ], + True, + ["Multiple company records returned user must clarify which one they want."], + ), + ( + "feedback-required-multiple-tools", + [ + AIMessage(content=("results-1", "pass", "False", "Multiple companies found, clarify which.")), + AIMessage(content=("results-2", "pass", "False", "Multiple interactions found, clarify which.")), + ], + True, + ["Multiple companies found, clarify which.", "Multiple interactions found, clarify which."], + ), + ( + "feedback-required-mixed-tools", + [ + AIMessage(content=("results-1", "pass", "False", "Clarification needed.")), + AIMessage(content=("results-2", "pass", "False")), + ], + True, + ["Clarification needed."], + ), + ( + "no-feedback-pass", + [ + AIMessage(content=("clean result", "pass", "False")), + ], + False, + [], + ), + ( + "no-feedback-fail-no-reason", + [ + AIMessage(content=("failed result", "fail", 
"False")), + ], + False, + [], + ), + ( + "no-feedback-multiple-pass", + [ + AIMessage(content=("result-1", "pass", "False")), + AIMessage(content=("result-2", "pass", "False")), + ], + False, + [], + ), + ], + ) + @pytest.mark.asyncio + async def test_feedback_required_vs_not( + self, + test_name, + tool_results, + expect_feedback, + expected_reasons, + fake_state, + mocker: MockerFixture, + mock_datahub_tools, + ): + """Test agent returns REQUIRES_USER_FEEDBACK when any tool result has a reason, otherwise COMPLETED.""" + res = AIMessage( + content="test", + tool_calls=[ + {"name": "test_tool", "args": {"is_intermediate_step": False}, "id": "fake-id", "type": "tool_call"} + ], + ) + llm = GenericFakeChatModel(messages=iter([res] * 10)) + mocker.patch("redbox.chains.runnables.get_chat_llm", return_value=llm) + + mock_tool_calls = mocker.patch("redbox.graph.nodes.processes.run_tools_parallel") + mock_tool_calls.return_value = tool_results + + agent_name = "Internal_Retrieval_Agent" + agent_task, multi_agent_plan = configure_agent_task_plan({agent_name: agent_name}) + tasks = [agent_task(task="Fake Task", expected_output="Fake output")] + plan = multi_agent_plan().model_copy(update={"tasks": tasks}) + + fake_state.user_feedback = "proceed" + fake_state.agent_plans = plan + fake_state.tasks_evaluator = "" + fake_state.messages = [AIMessage(content=plan.tasks[0].model_dump_json())] + fake_state.request.sso_token_getter = lambda: "fake-token" + + fake_agent = build_datahub_agent_with_loop( + agent_name=agent_name, + system_prompt="Fake prompt", + tools=[], + use_metadata=False, + max_tokens=10000, + pre_process=None, + loop_condition=lambda: True, + max_attempt=2, + ) + + response = await fake_agent.ainvoke(fake_state) + + assert response is not None + assert "agents_results" in response + assert "agent_plans" in response + + result_message = response["agents_results"]["task0"] + + if expect_feedback: + assert response["agent_plans"].get_task_status("task0") == 
TaskStatus.REQUIRES_USER_FEEDBACK + assert "Ask user for feedback based on failure reason" in result_message.content + for reason in expected_reasons: + assert reason in result_message.content + + collated_result = "" + feedback_reasons = [] + for i, tr in enumerate(tool_results): + if len(tr.content) > 3: + reason = tr.content[3] + feedback_reasons.append(f"Failure reason: {reason}.\n\n{tr.content[0]}") + else: + collated_result += f"{tr.content[0]}" + + combined_feedback = "\n\n".join(feedback_reasons) + + assert ( + result_message.content + == f"<{agent_name}_Result>Ask user for feedback based on failure reason. {combined_feedback}\n\n{collated_result}" + ) + else: + assert response["agent_plans"].get_task_status("task0") == TaskStatus.COMPLETED + assert "Ask user for feedback based on failure reason" not in result_message.content + + for tr in tool_results: + assert tr.content[0] in result_message.content + + collated_result = "" + for i, tr in enumerate(tool_results): + collated_result += f"{tr.content[0]}" + loop_result = " ".join([collated_result, collated_result]) + + assert result_message.content == f"<{agent_name}_Result>{loop_result}" + @pytest.mark.parametrize( "task_idx, task_status, expected",