Skip to content
4 changes: 2 additions & 2 deletions django_app/tests/test_consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ async def test_chat_consumer_with_new_session(

communicator = WebsocketCommunicator(ChatConsumer.as_asgi(), "/ws/chat/")
communicator.scope["user"] = alice
connected, _ = await communicator.connect()
connected, _ = await communicator.connect(timeout=5)
assert connected
with patch("redbox_app.redbox_core.consumers.ChatConsumer.redbox.graph", new=mocked_connect):
await communicator.send_json_to({"message": "Hello Hal."})
Expand Down Expand Up @@ -908,7 +908,7 @@ async def test_connect_with_agents_update_via_db(agents_list: list, alice: User)
mock_get.return_value = agents_list
communicator = WebsocketCommunicator(ChatConsumer.as_asgi(), "/ws/chat/")
communicator.scope["user"] = alice
await communicator.connect()
await communicator.connect(timeout=5)
assert mock_get.call_count == 1

assert "Fake_Agent" not in list(ChatConsumer.redbox.agent_configs.keys())
Expand Down
6 changes: 6 additions & 0 deletions redbox/redbox/api/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ def format_mcp_tool_response(tool_response, creator_type: ChunkCreatorType) -> t
deep_links = []
match result_type:
case "nullable":
if isinstance(result, str):
return result, metadata

deep_links = [(result.get("url"), result)]
case "paged":
deep_links = [(p.get("url"), p) for p in result.get("items", [])]
Expand Down Expand Up @@ -99,4 +102,7 @@ def format_mcp_tool_response(tool_response, creator_type: ChunkCreatorType) -> t
else:
response.append(json.dumps(item))

if not response:
return ("No results found.", metadata)

return ("\n\n".join(response), metadata)
50 changes: 31 additions & 19 deletions redbox/redbox/graph/nodes/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,37 +699,49 @@ def local_loop_condition():
result = "Tool error: no results received."

elif has_loop and len(ai_msg.tool_calls) > 0: # if loop, we need to transform results
result = result[-1].content # this is a tuple
# format of result: (result, success, is_intermediate_step)
log.warning("my-overall-result")
log.warning(result)
result_content = result[0]
success = result[1]
is_intermediate_step = eval(result[2])
collated_result = ""
feedback_reasons = []
for i, r in enumerate(result):
current_result = r.content # this is a tuple
# format of result: (result, success, is_intermediate_step)
log.debug("my-overall-result")
log.debug(current_result)
result_content = current_result[0]
success = current_result[1]
is_intermediate_step = eval(current_result[2])

if len(current_result) > 3:
reason = current_result[3]
feedback_reasons.append(f"Failure reason: {reason}.\n\n{result_content}")
else:
collated_result += f"<tool_result_{i}>{result_content}</tool_result_{i}>"

if success == "fail":
Comment thread
hwixley marked this conversation as resolved.
# pass error back if any
additional_variables.update({"previous_tool_error": result_content})
else:
# if success tool invocation, and intermediate steps then pass info back
if is_intermediate_step:
additional_variables.update(
{"previous_tool_error": "", "previous_tool_results": all_results}
)

if len(result) > 3:
reason = result[3]
if feedback_reasons:
combined_feedback = "\n\n".join(feedback_reasons)
return {
"agents_results": {
task.id: AIMessage(
content=f"<{agent_name}_Result>Ask user for feedback based on failure reason. Failure reason: {reason}.\n\n{result_content}</{agent_name}_Result>",
content=f"<{agent_name}_Result>Ask user for feedback based on failure reason. {combined_feedback}\n\n{collated_result}</{agent_name}_Result>",
kwargs={
"reason": reason,
"reason": combined_feedback,
},
)
},
"tasks_evaluator": task.task + "\n" + task.expected_output,
"agent_plans": state.agent_plans.update_task_status(task.id, TaskStatus.REQUIRES_USER_FEEDBACK),
}

if success == "fail":
# pass error back if any
additional_variables.update({"previous_tool_error": result_content})
else:
# if success tool invocation, and intermediate steps then pass info back
if is_intermediate_step:
additional_variables.update({"previous_tool_error": "", "previous_tool_results": all_results})
result = result_content
result = collated_result

if isinstance(result, str):
log.warning(f"{log_stub} Using raw string result.")
Expand Down
32 changes: 17 additions & 15 deletions redbox/redbox/graph/nodes/sends.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,24 +127,23 @@ def wrap_async_tool(tool, tool_name):
"""

def wrapper(args):
# Create a new event loop for this thread
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

# get mcp tool url
mcp_url = tool.metadata["url"]
creator_type = tool.metadata["creator_type"]
sso_access_token = tool.metadata["sso_access_token"].get()

try:
sso_access_token = tool.metadata["sso_access_token"].get()
except Exception as e:
log.error(f"wrap_async_tool - Failed to retrieve sso_access_token: {e}")
raise

if not sso_access_token:
log.error("wrap_async_tool - MCP sso_access_token is None")

headers = _get_mcp_headers(sso_access_token)

try:
# Define the async operation
async def run_tool():
# tool need to be executed within the connection context manager
async def run_tool():
try:
async with streamablehttp_client(mcp_url, headers=headers or None) as (
read,
write,
Expand All @@ -168,7 +167,7 @@ async def run_tool():
raise ValueError(f"tool with name '{tool_name}' not found")

# remove intermediate step argument if it is not required by tool
if "is_intermediate_step" not in selected_tool.args_schema["required"] and args.get(
if "is_intermediate_step" not in selected_tool.args_schema.get("required", []) and args.get(
"is_intermediate_step"
):
args.pop("is_intermediate_step")
Expand All @@ -193,12 +192,15 @@ async def run_tool():
f"wrap_async_tool - Returning raw MCP tool response for creator_type='{creator_type}'"
)
return result
except Exception as e:
log.error(f"wrap_async_tool - Failed to connect to MCP server at '{mcp_url}': {e}")
raise

# Run the async function and return its result
return loop.run_until_complete(run_tool())
finally:
# Clean up resources
loop.close()
try:
return asyncio.run(run_tool())
except Exception as e:
log.error(f"wrap_async_tool - Unhandled error running tool '{tool_name}': {e}", exc_info=True)
raise

return wrapper

Expand Down
122 changes: 122 additions & 0 deletions redbox/tests/graph/nodes/test_sends.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,47 @@ def test_returns_expected_results(
assert result == expected_documents
assert metadata == expected_tool_metadata

@pytest.mark.parametrize("expected_tool_result, expected_documents", MCP_TOOL_RESULTS)
@patch("redbox.graph.nodes.sends.ClientSession")
@patch("redbox.graph.nodes.sends.streamablehttp_client")
@patch("redbox.graph.nodes.sends.load_mcp_tools")
def test_returns_expected_results_no_args(
    self,
    mock_load_tools,
    mock_http_client,
    mock_session_class,
    fake_mcp_tool,
    expected_tool_result,
    expected_documents,
):
    """Check wrap_async_tool when both the tool's args schema and the call args are empty.

    The fake tool is built with an empty ``args_schema`` (so it has no
    "required" entry) and the wrapped function is invoked with an empty
    argument dict. The test then verifies the MCP client/session/tool
    interactions and that the formatted result and metadata match the
    parametrized expectations.
    """
    raw_content, tool_metadata = expected_tool_result

    # Fake MCP tool whose schema declares nothing at all.
    name = "company_tool"
    mcp_tool = fake_mcp_tool(name, return_value=raw_content, args_schema={})

    # Patched MCP environment; the helper hands back the mocked session.
    session = self._patch_mcp_env(mock_load_tools, mock_http_client, mock_session_class, [mcp_tool])

    # Invoke the wrapped tool with no arguments.
    call_args = {}
    invoke = wrap_async_tool(mcp_tool, name)
    documents, metadata = invoke(call_args)

    # The connection, session and tool must each have been used exactly once.
    mock_http_client.assert_called_once_with(mcp_tool.metadata["url"], headers=None)
    session.initialize.assert_called_once()
    mock_load_tools.assert_called_once_with(session)
    mcp_tool.ainvoke.assert_called_once_with(call_args)

    # The returned payload must match the expected output for this case.
    assert documents == expected_documents
    assert metadata == tool_metadata

@patch("redbox.graph.nodes.sends.ClientSession")
@patch("redbox.graph.nodes.sends.streamablehttp_client")
@patch("redbox.graph.nodes.sends.load_mcp_tools", new_callable=AsyncMock)
Expand All @@ -576,6 +617,87 @@ def test_tool_not_found(self, mock_load_tools, mock_http_client, mock_session_cl
with pytest.raises(ValueError, match="tool with name 'missing_tool' not found"):
wrapped_func({"foo": "bar"})

def test_sso_token_retrieval_failure(self, fake_mcp_tool):
    """An exception raised by ``sso_access_token.get()`` must propagate out of the wrapper."""
    name = "dummy_tool"

    # Token holder whose .get() always blows up.
    broken_token = MagicMock()
    broken_token.get.side_effect = RuntimeError("vault unavailable")

    mcp_tool = fake_mcp_tool(name, return_value=None)
    mcp_tool.metadata["sso_access_token"] = broken_token

    invoke = wrap_async_tool(mcp_tool, name)

    with pytest.raises(RuntimeError, match="vault unavailable"):
        invoke({"foo": "bar"})

@patch("redbox.graph.nodes.sends.ClientSession")
@patch("redbox.graph.nodes.sends.streamablehttp_client")
@patch("redbox.graph.nodes.sends.load_mcp_tools")
def test_intermediate_step_stripped_when_not_in_schema(
    self,
    mock_load_tools,
    mock_http_client,
    mock_session_class,
    fake_mcp_tool,
):
    """``is_intermediate_step`` is dropped from the args when the schema does not require it."""
    name = "company_tool"
    content = "some content"
    # Only company_name is listed as required.
    schema = {"company_name": {"type": "string"}, "required": ["company_name"]}
    mcp_tool = fake_mcp_tool(name, return_value=content, args_schema=schema)

    # Swap in a mock creator type (per the original author: a non-datahub
    # type, so the raw result is returned) and an awaitable ainvoke.
    mcp_tool.metadata["creator_type"] = MagicMock()
    mcp_tool.ainvoke = AsyncMock(return_value=content)

    self._patch_mcp_env(mock_load_tools, mock_http_client, mock_session_class, [mcp_tool])

    invoke = wrap_async_tool(mcp_tool, name)
    call_args = {"company_name": "BMW", "is_intermediate_step": True}
    invoke(call_args)

    # The flag must be gone by the time the tool itself is invoked.
    mcp_tool.ainvoke.assert_called_once_with({"company_name": "BMW"})

@patch("redbox.graph.nodes.sends.ClientSession")
@patch("redbox.graph.nodes.sends.streamablehttp_client")
@patch("redbox.graph.nodes.sends.load_mcp_tools")
def test_intermediate_step_retained_when_in_schema(
    self,
    mock_load_tools,
    mock_http_client,
    mock_session_class,
    fake_mcp_tool,
):
    """``is_intermediate_step`` stays in the args when the schema lists it as required."""
    name = "company_tool"
    content = "some content"
    # Both company_name and is_intermediate_step are listed as required.
    schema = {"company_name": {"type": "string"}, "required": ["company_name", "is_intermediate_step"]}
    mcp_tool = fake_mcp_tool(name, return_value=content, args_schema=schema)

    # Mock creator type plus an awaitable ainvoke that yields the plain content.
    mcp_tool.metadata["creator_type"] = MagicMock()
    mcp_tool.ainvoke = AsyncMock(return_value=content)

    self._patch_mcp_env(mock_load_tools, mock_http_client, mock_session_class, [mcp_tool])

    invoke = wrap_async_tool(mcp_tool, name)
    call_args = {"company_name": "BMW", "is_intermediate_step": True}
    invoke(call_args)

    # The flag must reach the tool untouched.
    mcp_tool.ainvoke.assert_called_once_with({"company_name": "BMW", "is_intermediate_step": True})

@patch("redbox.graph.nodes.sends.asyncio")
def test_asyncio_run_failure(self, mock_asyncio, fake_mcp_tool):
    """Test that wrap_async_tool re-raises when asyncio.run itself fails.

    The ``asyncio`` name inside ``redbox.graph.nodes.sends`` is replaced by a
    mock whose ``run`` raises ``RuntimeError``. The wrapper is expected to let
    that error propagate to the caller.
    """

    def fail_run(coro, *args, **kwargs):
        # The wrapper hands asyncio.run a real coroutine object. Because the
        # mocked run never drives it, close it explicitly so Python does not
        # emit "coroutine ... was never awaited" (which fails the test when
        # warnings are treated as errors).
        coro.close()
        raise RuntimeError("event loop closed")

    mock_asyncio.run.side_effect = fail_run

    tool = fake_mcp_tool("dummy_tool", return_value=None)
    wrapped_func = wrap_async_tool(tool, "dummy_tool")

    with pytest.raises(RuntimeError, match="event loop closed"):
        wrapped_func({"foo": "bar"})

    # Confirm the error really originated from the asyncio.run call rather
    # than from an earlier step in the wrapper.
    mock_asyncio.run.assert_called_once()


@pytest.mark.parametrize(
"token_input, expected_output",
Expand Down
Loading
Loading