diff --git a/django_app/tests/test_consumers.py b/django_app/tests/test_consumers.py
index 55cc84245..35c0d4df3 100644
--- a/django_app/tests/test_consumers.py
+++ b/django_app/tests/test_consumers.py
@@ -67,7 +67,7 @@ async def test_chat_consumer_with_new_session(
communicator = WebsocketCommunicator(ChatConsumer.as_asgi(), "/ws/chat/")
communicator.scope["user"] = alice
- connected, _ = await communicator.connect()
+ connected, _ = await communicator.connect(timeout=5)
assert connected
with patch("redbox_app.redbox_core.consumers.ChatConsumer.redbox.graph", new=mocked_connect):
await communicator.send_json_to({"message": "Hello Hal."})
@@ -908,7 +908,7 @@ async def test_connect_with_agents_update_via_db(agents_list: list, alice: User)
mock_get.return_value = agents_list
communicator = WebsocketCommunicator(ChatConsumer.as_asgi(), "/ws/chat/")
communicator.scope["user"] = alice
- await communicator.connect()
+ await communicator.connect(timeout=5)
assert mock_get.call_count == 1
assert "Fake_Agent" not in list(ChatConsumer.redbox.agent_configs.keys())
diff --git a/redbox/redbox/api/format.py b/redbox/redbox/api/format.py
index 76364d26c..554ddbb98 100644
--- a/redbox/redbox/api/format.py
+++ b/redbox/redbox/api/format.py
@@ -69,6 +69,9 @@ def format_mcp_tool_response(tool_response, creator_type: ChunkCreatorType) -> t
deep_links = []
match result_type:
case "nullable":
+ if isinstance(result, str):
+ return result, metadata
+
deep_links = [(result.get("url"), result)]
case "paged":
deep_links = [(p.get("url"), p) for p in result.get("items", [])]
@@ -99,4 +102,7 @@ def format_mcp_tool_response(tool_response, creator_type: ChunkCreatorType) -> t
else:
response.append(json.dumps(item))
+ if not response:
+ return ("No results found.", metadata)
+
return ("\n\n".join(response), metadata)
diff --git a/redbox/redbox/graph/nodes/processes.py b/redbox/redbox/graph/nodes/processes.py
index bf7fa38cf..a0fa46c1d 100644
--- a/redbox/redbox/graph/nodes/processes.py
+++ b/redbox/redbox/graph/nodes/processes.py
@@ -699,22 +699,41 @@ def local_loop_condition():
result = "Tool error: no results received."
elif has_loop and len(ai_msg.tool_calls) > 0: # if loop, we need to transform results
- result = result[-1].content # this is a tuple
- # format of result: (result, success, is_intermediate_step)
- log.warning("my-overall-result")
- log.warning(result)
- result_content = result[0]
- success = result[1]
- is_intermediate_step = eval(result[2])
+ collated_result = ""
+ feedback_reasons = []
+ for i, r in enumerate(result):
+ current_result = r.content # this is a tuple
+ # format of result: (result, success, is_intermediate_step)
+ log.debug("my-overall-result")
+ log.debug(current_result)
+ result_content = current_result[0]
+ success = current_result[1]
+ is_intermediate_step = eval(current_result[2])
+
+ if len(current_result) > 3:
+ reason = current_result[3]
+ feedback_reasons.append(f"Failure reason: {reason}.\n\n{result_content}")
+ else:
+ collated_result += f"{result_content}"
+
+ if success == "fail":
+ # pass error back if any
+ additional_variables.update({"previous_tool_error": result_content})
+ else:
+ # if success tool invocation, and intermediate steps then pass info back
+ if is_intermediate_step:
+ additional_variables.update(
+ {"previous_tool_error": "", "previous_tool_results": all_results}
+ )
- if len(result) > 3:
- reason = result[3]
+ if feedback_reasons:
+ combined_feedback = "\n\n".join(feedback_reasons)
return {
"agents_results": {
task.id: AIMessage(
- content=f"<{agent_name}_Result>Ask user for feedback based on failure reason. Failure reason: {reason}.\n\n{result_content}{agent_name}_Result>",
+ content=f"<{agent_name}_Result>Ask user for feedback based on failure reason. {combined_feedback}\n\n{collated_result}{agent_name}_Result>",
kwargs={
- "reason": reason,
+ "reason": combined_feedback,
},
)
},
@@ -722,14 +741,7 @@ def local_loop_condition():
"agent_plans": state.agent_plans.update_task_status(task.id, TaskStatus.REQUIRES_USER_FEEDBACK),
}
- if success == "fail":
- # pass error back if any
- additional_variables.update({"previous_tool_error": result_content})
- else:
- # if success tool invocation, and intermediate steps then pass info back
- if is_intermediate_step:
- additional_variables.update({"previous_tool_error": "", "previous_tool_results": all_results})
- result = result_content
+ result = collated_result
if isinstance(result, str):
log.warning(f"{log_stub} Using raw string result.")
diff --git a/redbox/redbox/graph/nodes/sends.py b/redbox/redbox/graph/nodes/sends.py
index de3412c27..71adadf11 100644
--- a/redbox/redbox/graph/nodes/sends.py
+++ b/redbox/redbox/graph/nodes/sends.py
@@ -127,24 +127,23 @@ def wrap_async_tool(tool, tool_name):
"""
def wrapper(args):
- # Create a new event loop for this thread
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
# get mcp tool url
mcp_url = tool.metadata["url"]
creator_type = tool.metadata["creator_type"]
- sso_access_token = tool.metadata["sso_access_token"].get()
+
+ try:
+ sso_access_token = tool.metadata["sso_access_token"].get()
+ except Exception as e:
+ log.error(f"wrap_async_tool - Failed to retrieve sso_access_token: {e}")
+ raise
if not sso_access_token:
log.error("wrap_async_tool - MCP sso_access_token is None")
headers = _get_mcp_headers(sso_access_token)
- try:
- # Define the async operation
- async def run_tool():
- # tool need to be executed within the connection context manager
+ async def run_tool():
+ try:
async with streamablehttp_client(mcp_url, headers=headers or None) as (
read,
write,
@@ -168,7 +167,7 @@ async def run_tool():
raise ValueError(f"tool with name '{tool_name}' not found")
# remove intermediate step argument if it is not required by tool
- if "is_intermediate_step" not in selected_tool.args_schema["required"] and args.get(
+ if "is_intermediate_step" not in selected_tool.args_schema.get("required", []) and args.get(
"is_intermediate_step"
):
args.pop("is_intermediate_step")
@@ -193,12 +192,15 @@ async def run_tool():
f"wrap_async_tool - Returning raw MCP tool response for creator_type='{creator_type}'"
)
return result
+ except Exception as e:
+ log.error(f"wrap_async_tool - Failed to connect to MCP server at '{mcp_url}': {e}")
+ raise
- # Run the async function and return its result
- return loop.run_until_complete(run_tool())
- finally:
- # Clean up resources
- loop.close()
+ try:
+ return asyncio.run(run_tool())
+ except Exception as e:
+ log.error(f"wrap_async_tool - Unhandled error running tool '{tool_name}': {e}", exc_info=True)
+ raise
return wrapper
diff --git a/redbox/tests/graph/nodes/test_sends.py b/redbox/tests/graph/nodes/test_sends.py
index c5e310bc9..e70faaaa5 100644
--- a/redbox/tests/graph/nodes/test_sends.py
+++ b/redbox/tests/graph/nodes/test_sends.py
@@ -561,6 +561,47 @@ def test_returns_expected_results(
assert result == expected_documents
assert metadata == expected_tool_metadata
+ @pytest.mark.parametrize("expected_tool_result, expected_documents", MCP_TOOL_RESULTS)
+ @patch("redbox.graph.nodes.sends.ClientSession")
+ @patch("redbox.graph.nodes.sends.streamablehttp_client")
+ @patch("redbox.graph.nodes.sends.load_mcp_tools")
+ def test_returns_expected_results_no_args(
+ self,
+ mock_load_tools,
+ mock_http_client,
+ mock_session_class,
+ fake_mcp_tool,
+ expected_tool_result,
+ expected_documents,
+ ):
+ """Test that wrap_async_tool correctly returns results from async tool invocation"""
+ expected_tool_content, expected_tool_metadata = expected_tool_result
+
+ # Mock tool with metadata
+ tool_name = "company_tool"
+ args_schema = {}
+ tool = fake_mcp_tool(tool_name, return_value=expected_tool_content, args_schema=args_schema)
+
+ # mock session with patched mcp setup
+ mock_session = self._patch_mcp_env(mock_load_tools, mock_http_client, mock_session_class, [tool])
+
+ # create the wrapped function
+ wrapped_func = wrap_async_tool(tool, tool_name)
+
+ # test invocation with sample args
+ test_args = {}
+ result, metadata = wrapped_func(test_args)
+
+ # verify correct interactions
+ mock_http_client.assert_called_once_with(tool.metadata["url"], headers=None)
+ mock_session.initialize.assert_called_once()
+ mock_load_tools.assert_called_once_with(mock_session)
+ tool.ainvoke.assert_called_once_with(test_args)
+
+ # assert the result matches our expected output
+ assert result == expected_documents
+ assert metadata == expected_tool_metadata
+
@patch("redbox.graph.nodes.sends.ClientSession")
@patch("redbox.graph.nodes.sends.streamablehttp_client")
@patch("redbox.graph.nodes.sends.load_mcp_tools", new_callable=AsyncMock)
@@ -576,6 +617,87 @@ def test_tool_not_found(self, mock_load_tools, mock_http_client, mock_session_cl
with pytest.raises(ValueError, match="tool with name 'missing_tool' not found"):
wrapped_func({"foo": "bar"})
+ def test_sso_token_retrieval_failure(self, fake_mcp_tool):
+ """Test that wrap_async_tool raises when sso_access_token.get() fails."""
+ tool = fake_mcp_tool("dummy_tool", return_value=None)
+ tool.metadata["sso_access_token"] = MagicMock()
+ tool.metadata["sso_access_token"].get.side_effect = RuntimeError("vault unavailable")
+
+ wrapped_func = wrap_async_tool(tool, "dummy_tool")
+
+ with pytest.raises(RuntimeError, match="vault unavailable"):
+ wrapped_func({"foo": "bar"})
+
+ @patch("redbox.graph.nodes.sends.ClientSession")
+ @patch("redbox.graph.nodes.sends.streamablehttp_client")
+ @patch("redbox.graph.nodes.sends.load_mcp_tools")
+ def test_intermediate_step_stripped_when_not_in_schema(
+ self,
+ mock_load_tools,
+ mock_http_client,
+ mock_session_class,
+ fake_mcp_tool,
+ ):
+ """Test that is_intermediate_step is removed from args when not required by the tool schema."""
+ return_value = "some content"
+ tool_name = "company_tool"
+ args_schema = {"company_name": {"type": "string"}, "required": ["company_name"]}
+ tool = fake_mcp_tool(tool_name, return_value=return_value, args_schema=args_schema)
+
+ # make ainvoke return something format_mcp_tool_response can handle, or use a non-datahub type
+ tool.metadata["creator_type"] = MagicMock() # non-datahub, returns raw result
+ tool.ainvoke = AsyncMock(return_value=return_value)
+
+ self._patch_mcp_env(mock_load_tools, mock_http_client, mock_session_class, [tool])
+
+ wrapped_func = wrap_async_tool(tool, tool_name)
+
+ test_args = {"company_name": "BMW", "is_intermediate_step": True}
+ wrapped_func(test_args)
+
+ # is_intermediate_step should have been popped before ainvoke was called
+ tool.ainvoke.assert_called_once_with({"company_name": "BMW"})
+
+ @patch("redbox.graph.nodes.sends.ClientSession")
+ @patch("redbox.graph.nodes.sends.streamablehttp_client")
+ @patch("redbox.graph.nodes.sends.load_mcp_tools")
+ def test_intermediate_step_retained_when_in_schema(
+ self,
+ mock_load_tools,
+ mock_http_client,
+ mock_session_class,
+ fake_mcp_tool,
+ ):
+ """Test that is_intermediate_step is kept in args when the tool schema requires it."""
+ return_value = "some content"
+ tool_name = "company_tool"
+ args_schema = {"company_name": {"type": "string"}, "required": ["company_name", "is_intermediate_step"]}
+ tool = fake_mcp_tool(tool_name, return_value=return_value, args_schema=args_schema)
+
+ tool.metadata["creator_type"] = MagicMock()
+ tool.ainvoke = AsyncMock(return_value=return_value)
+
+ self._patch_mcp_env(mock_load_tools, mock_http_client, mock_session_class, [tool])
+
+ wrapped_func = wrap_async_tool(tool, tool_name)
+
+ test_args = {"company_name": "BMW", "is_intermediate_step": True}
+ wrapped_func(test_args)
+
+ # is_intermediate_step should be preserved
+ tool.ainvoke.assert_called_once_with({"company_name": "BMW", "is_intermediate_step": True})
+
+ @patch("redbox.graph.nodes.sends.asyncio")
+ def test_asyncio_run_failure(self, mock_asyncio, fake_mcp_tool):
+ """Test that wrap_async_tool re-raises when asyncio.run itself fails."""
+ tool = fake_mcp_tool("dummy_tool", return_value=None)
+ mock_asyncio.run.side_effect = RuntimeError("event loop closed")
+
+ wrapped_func = wrap_async_tool(tool, "dummy_tool")
+
+ with pytest.raises(RuntimeError, match="event loop closed"):
+ wrapped_func({"foo": "bar"})
+
@pytest.mark.parametrize(
"token_input, expected_output",
diff --git a/redbox/tests/graph/test_patterns.py b/redbox/tests/graph/test_patterns.py
index a5ed48c46..dac6df798 100644
--- a/redbox/tests/graph/test_patterns.py
+++ b/redbox/tests/graph/test_patterns.py
@@ -1017,6 +1017,158 @@ async def test_llm_response_truncation(
assert len(response) == 3 # content, eval task, task status
assert mock_tool_calls.call_count == 1
+ @pytest.mark.parametrize(
+ "test_name, tool_results, expect_feedback, expected_reasons",
+ [
+ (
+ "feedback-required-single-tool",
+ [
+ AIMessage(
+ content=(
+ "some results",
+ "pass",
+ "False",
+ "Multiple company records returned user must clarify which one they want.",
+ )
+ )
+ ],
+ True,
+ ["Multiple company records returned user must clarify which one they want."],
+ ),
+ (
+ "feedback-required-multiple-tools",
+ [
+ AIMessage(content=("results-1", "pass", "False", "Multiple companies found, clarify which.")),
+ AIMessage(content=("results-2", "pass", "False", "Multiple interactions found, clarify which.")),
+ ],
+ True,
+ ["Multiple companies found, clarify which.", "Multiple interactions found, clarify which."],
+ ),
+ (
+ "feedback-required-mixed-tools",
+ [
+ AIMessage(content=("results-1", "pass", "False", "Clarification needed.")),
+ AIMessage(content=("results-2", "pass", "False")),
+ ],
+ True,
+ ["Clarification needed."],
+ ),
+ (
+ "no-feedback-pass",
+ [
+ AIMessage(content=("clean result", "pass", "False")),
+ ],
+ False,
+ [],
+ ),
+ (
+ "no-feedback-fail-no-reason",
+ [
+ AIMessage(content=("failed result", "fail", "False")),
+ ],
+ False,
+ [],
+ ),
+ (
+ "no-feedback-multiple-pass",
+ [
+ AIMessage(content=("result-1", "pass", "False")),
+ AIMessage(content=("result-2", "pass", "False")),
+ ],
+ False,
+ [],
+ ),
+ ],
+ )
+ @pytest.mark.asyncio
+ async def test_feedback_required_vs_not(
+ self,
+ test_name,
+ tool_results,
+ expect_feedback,
+ expected_reasons,
+ fake_state,
+ mocker: MockerFixture,
+ mock_datahub_tools,
+ ):
+ """Test agent returns REQUIRES_USER_FEEDBACK when any tool result has a reason, otherwise COMPLETED."""
+ res = AIMessage(
+ content="test",
+ tool_calls=[
+ {"name": "test_tool", "args": {"is_intermediate_step": False}, "id": "fake-id", "type": "tool_call"}
+ ],
+ )
+ llm = GenericFakeChatModel(messages=iter([res] * 10))
+ mocker.patch("redbox.chains.runnables.get_chat_llm", return_value=llm)
+
+ mock_tool_calls = mocker.patch("redbox.graph.nodes.processes.run_tools_parallel")
+ mock_tool_calls.return_value = tool_results
+
+ agent_name = "Internal_Retrieval_Agent"
+ agent_task, multi_agent_plan = configure_agent_task_plan({agent_name: agent_name})
+ tasks = [agent_task(task="Fake Task", expected_output="Fake output")]
+ plan = multi_agent_plan().model_copy(update={"tasks": tasks})
+
+ fake_state.user_feedback = "proceed"
+ fake_state.agent_plans = plan
+ fake_state.tasks_evaluator = ""
+ fake_state.messages = [AIMessage(content=plan.tasks[0].model_dump_json())]
+ fake_state.request.sso_token_getter = lambda: "fake-token"
+
+ fake_agent = build_datahub_agent_with_loop(
+ agent_name=agent_name,
+ system_prompt="Fake prompt",
+ tools=[],
+ use_metadata=False,
+ max_tokens=10000,
+ pre_process=None,
+ loop_condition=lambda: True,
+ max_attempt=2,
+ )
+
+ response = await fake_agent.ainvoke(fake_state)
+
+ assert response is not None
+ assert "agents_results" in response
+ assert "agent_plans" in response
+
+ result_message = response["agents_results"]["task0"]
+
+ if expect_feedback:
+ assert response["agent_plans"].get_task_status("task0") == TaskStatus.REQUIRES_USER_FEEDBACK
+ assert "Ask user for feedback based on failure reason" in result_message.content
+ for reason in expected_reasons:
+ assert reason in result_message.content
+
+ collated_result = ""
+ feedback_reasons = []
+ for i, tr in enumerate(tool_results):
+ if len(tr.content) > 3:
+ reason = tr.content[3]
+ feedback_reasons.append(f"Failure reason: {reason}.\n\n{tr.content[0]}")
+ else:
+ collated_result += f"{tr.content[0]}"
+
+ combined_feedback = "\n\n".join(feedback_reasons)
+
+ assert (
+ result_message.content
+ == f"<{agent_name}_Result>Ask user for feedback based on failure reason. {combined_feedback}\n\n{collated_result}{agent_name}_Result>"
+ )
+ else:
+ assert response["agent_plans"].get_task_status("task0") == TaskStatus.COMPLETED
+ assert "Ask user for feedback based on failure reason" not in result_message.content
+
+ for tr in tool_results:
+ assert tr.content[0] in result_message.content
+
+ collated_result = ""
+ for i, tr in enumerate(tool_results):
+ collated_result += f"{tr.content[0]}"
+ loop_result = " ".join([collated_result, collated_result])
+
+ assert result_message.content == f"<{agent_name}_Result>{loop_result}{agent_name}_Result>"
+
@pytest.mark.parametrize(
"task_idx, task_status, expected",