Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,7 @@
"titleBar.activeBackground": "#911212",
"titleBar.activeForeground": "#fbfbec"
},
"terminal.integrated.scrollback": 100000
"terminal.integrated.scrollback": 100000,
"python-envs.defaultEnvManager": "ms-python.python:poetry",
"python-envs.defaultPackageManager": "ms-python.python:poetry"
}
23 changes: 23 additions & 0 deletions redbox/redbox/graph/nodes/runner/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
class BaseToolRunnerException(Exception):
    """Root of the tool-runner exception hierarchy; wraps a single message string."""

    def __init__(self, message: str) -> None:
        super().__init__(message)


class ToolNotFoundError(BaseToolRunnerException):
    """Raised when a tool call names a tool that is not in the runner's tool list."""

    pass


class ToolValidationError(BaseToolRunnerException):
    """Raised when tool arguments are malformed or a tool response body is empty."""

    pass


class ToolExecutionError(BaseToolRunnerException):
    """Raised when a tool fails to submit, raises during execution, or returns None."""

    pass


class ToolTimeoutError(BaseToolRunnerException):
    """Raised when a tool does not produce a result within the allowed time."""

    pass


class ToolRegistryError(BaseToolRunnerException):
    # NOTE(review): not raised anywhere in the visible code — presumably
    # reserved for registry-level failures; confirm before removing.
    """Raised for tool-registry level failures."""

    pass
192 changes: 192 additions & 0 deletions redbox/redbox/graph/nodes/runner/runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
import logging
from uuid import uuid4
from typing import Optional
from concurrent.futures import ThreadPoolExecutor, TimeoutError, as_completed, Future

from langchain_core.messages import AIMessage, ToolCall
from langchain.tools import StructuredTool

from redbox.models.chain import RedboxState
from redbox.api.format import MCPResponseMetadata
from redbox.graph.nodes.runner import exceptions as tool_exceptions
from redbox.graph.nodes.runner.wrap_async import wrap_async_tool

log = logging.getLogger(__name__)


class ToolRunner:
    """Encapsulates the logic for submitting and parsing individual tool futures.

    Each instance owns a ``ThreadPoolExecutor``. :meth:`run` fans a batch of
    tool calls out to the pool and gathers the results under one overall
    deadline (``parallel_timeout`` seconds).

    NOTE(review): the executor is never shut down explicitly — confirm callers
    rely on interpreter exit / garbage collection for cleanup, or add an
    explicit shutdown hook.
    """

    def __init__(
        self,
        tools: list[StructuredTool],
        state: RedboxState,
        max_workers: int,
        is_loop: bool,
        parallel_timeout: float,
    ):
        self.tools = tools
        self.state = state
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.is_loop = is_loop
        self.parallel_timeout = parallel_timeout
        # Short random id so log lines from interleaved runs can be told apart.
        self.log_stub = f"[run_tools_parallel run_id='{str(uuid4())[:8]}']"

    def run(self, tool_calls: list[ToolCall]) -> list[AIMessage] | None:
        """Submit all tool calls, collect results, and return aggregated responses or None on total failure."""
        futures = self._submit_all(tool_calls=tool_calls)
        return self._collect(futures=futures)

    def _submit_all(self, tool_calls: list[ToolCall]) -> dict[Future, dict]:
        """Submit every tool call to the executor, skipping and logging any that fail to launch.

        Returns a mapping of ``Future`` -> metadata dict (``name`` and
        ``intermediate_step``), consumed by :meth:`_collect`.
        """
        futures: dict[Future, dict] = {}

        for tool_call in tool_calls:
            tool_name = tool_call.get("name")
            try:
                res = self.submit(tool_call=tool_call)
                if res is None:
                    continue
                future, metadata = res
                futures[future] = metadata

            except tool_exceptions.ToolNotFoundError as e:
                log.warning(f"{self.log_stub} Tool '{tool_name}' not found: {e}")

            except tool_exceptions.ToolValidationError as e:
                log.warning(f"{self.log_stub} Tool '{tool_name}' validation error: {e}")

            except tool_exceptions.BaseToolRunnerException as e:
                log.warning(f"{self.log_stub} Tool '{tool_name}' error: {e}")

        return futures

    def _collect(self, futures: dict[Future, dict]) -> list[AIMessage] | None:
        """Wait for all futures, parse results, and return responses or None if everything failed.

        BUG FIX: ``as_completed(..., timeout=...)`` raises ``TimeoutError``
        from the *iterator itself* when the overall deadline expires.
        Previously that exception escaped this method, discarding every
        already-collected response. It is now caught: still-pending futures
        are cancelled (best-effort) and recorded as failed, and whatever
        completed in time is returned.
        """
        responses: list[AIMessage] = []
        failed_tools: list[str] = []
        # Futures not yet yielded by as_completed; drained as results arrive.
        pending: set[Future] = set(futures)

        try:
            for future in as_completed(futures.keys(), timeout=self.parallel_timeout):
                pending.discard(future)
                future_tool_name = futures[future]["name"]
                try:
                    response = self.parse(future=future, metadata=futures[future])
                    if response is not None:
                        responses.append(response)

                except tool_exceptions.ToolTimeoutError as e:
                    log.warning(f"{self.log_stub} Tool '{future_tool_name}' timed out: {e}")
                    failed_tools.append(future_tool_name)

                except tool_exceptions.ToolValidationError as e:
                    log.warning(f"{self.log_stub} Tool '{future_tool_name}' validation error: {e}")
                    failed_tools.append(future_tool_name)

                except tool_exceptions.ToolExecutionError as e:
                    log.warning(f"{self.log_stub} Tool '{future_tool_name}' execution error: {e}")
                    failed_tools.append(future_tool_name)

                except tool_exceptions.BaseToolRunnerException as e:
                    log.warning(f"{self.log_stub} Tool '{future_tool_name}' error: {e}")
                    failed_tools.append(future_tool_name)

        except TimeoutError:
            # Overall deadline hit: mark everything still pending as failed.
            for fut in pending:
                fut.cancel()  # best-effort; already-running futures keep running
                name = futures[fut]["name"]
                failed_tools.append(name)
                log.warning(
                    f"{self.log_stub} Tool '{name}' did not finish within the "
                    f"{self.parallel_timeout:.1f}s overall deadline"
                )

        if failed_tools:
            log.warning(f"{self.log_stub} {len(failed_tools)} tool(s) failed: {', '.join(failed_tools)}")

        if not responses:
            log.warning(
                f"{self.log_stub} Every tool execution has failed or timed out. "
                f"Failed tools: {', '.join(failed_tools) or 'unknown'}."
            )
            return None

        log.warning(
            f"{self.log_stub} Completed. Successful: {len(responses)}, "
            f"Failed: {len(failed_tools)}. Responses: {responses}"
        )
        return responses

    def submit(self, tool_call: ToolCall) -> tuple[Future, dict] | None:
        """Find, validate, and submit a tool call to the executor. Returns (future, metadata) or None.

        Raises:
            ToolNotFoundError: the named tool is not in ``self.tools``.
            ToolValidationError: the call's ``args`` is not a dict.
            ToolExecutionError: the executor submission itself failed.
        """
        tool_name = tool_call.get("name")
        selected_tool: Optional[StructuredTool] = next((tool for tool in self.tools if tool.name == tool_name), None)

        if selected_tool is None:
            available = [tool.name for tool in self.tools]
            raise tool_exceptions.ToolNotFoundError(
                f"Tool '{tool_name}' not found. Available tools: {', '.join(available)}"
            )

        raw_args = tool_call.get("args", {})
        if not isinstance(raw_args, dict):
            raise tool_exceptions.ToolValidationError(
                f"Invalid input for tool '{tool_name}': expected dict, got {type(raw_args).__name__!r}"
            )

        # NOTE(review): the intermediate-step flag is carried as the *string*
        # "False"/"True" rather than a bool — downstream code appears to expect
        # that; confirm before changing the type.
        is_intermediate_step = "False"

        try:
            if selected_tool.func and not selected_tool.coroutine:
                # Local synchronous tool: inject the graph state into its args.
                args = {**raw_args, "state": self.state}
                future = self.executor.submit(selected_tool.invoke, args)
            else:
                # Async / MCP tool: run through the sync wrapper instead.
                args = {**raw_args}
                if self.is_loop:
                    is_intermediate_step = args.get("is_intermediate_step", "False")
                    log.warning(f"intermediate step: {is_intermediate_step}")
                future = self.executor.submit(wrap_async_tool(selected_tool, tool_name), args)
        except Exception as e:
            raise tool_exceptions.ToolExecutionError(
                f"Failed to submit tool '{tool_name}' for execution: {str(e)}"
            ) from e

        return future, {"name": tool_name, "intermediate_step": is_intermediate_step}

    def parse(self, future: Future, metadata: dict) -> Optional[AIMessage]:
        """Resolve a completed future and transform its result into an AIMessage.

        Raises:
            ToolTimeoutError: defensive only — ``future.result()`` with no
                timeout argument never raises TimeoutError for futures yielded
                by ``as_completed`` (they are already done); kept in case this
                method is ever called directly with an unfinished future.
            ToolExecutionError: the tool raised, or returned None.
            ToolValidationError: the response body is empty/whitespace-only.
        """
        future_tool_name = metadata["name"]
        is_intermediate_step = metadata["intermediate_step"]

        try:
            response = future.result()
        except TimeoutError as e:
            raise tool_exceptions.ToolTimeoutError(
                f"Tool '{future_tool_name}' timed out after {self.parallel_timeout:.1f}s"
            ) from e
        except Exception as e:
            raise tool_exceptions.ToolExecutionError(f"Tool '{future_tool_name}' failed: {str(e)}") from e

        log.warning(f"{self.log_stub} This is what I got from tool '{future_tool_name}': {response}")

        if response is None:
            raise tool_exceptions.ToolExecutionError(
                f"Tool '{future_tool_name}' returned None — may have failed or timed out"
            )

        log.warning(f"{self.log_stub} {future_tool_name} response not None")

        result = response
        if not self.is_loop:
            if isinstance(response, tuple) and isinstance(response[1], MCPResponseMetadata):
                result = response[0]

        else:
            if isinstance(response, tuple) and isinstance(response[1], MCPResponseMetadata):
                res = response[0]
                # BUG FIX: renamed from ``metadata`` — the original rebinding
                # shadowed the ``metadata`` parameter with a different type.
                mcp_metadata: MCPResponseMetadata = response[1]
                status = "pass" if res != "" else "fail"
                result = (
                    (
                        res,
                        status,
                        is_intermediate_step,
                        mcp_metadata.user_feedback.reason or "Requires feedback from the user.",
                    )
                    if mcp_metadata.user_feedback.required
                    else (res, status, is_intermediate_step)
                )

        # Validate the text payload regardless of whether result is a tuple.
        raw_res = result[0] if isinstance(result, tuple) else result
        if not raw_res or not isinstance(raw_res, str) or not raw_res.strip():
            raise tool_exceptions.ToolValidationError(f"empty or whitespace-only response body: {repr(raw_res)}")

        return AIMessage(result)
101 changes: 101 additions & 0 deletions redbox/redbox/graph/nodes/runner/wrap_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import logging
import asyncio
from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client
from langchain_mcp_adapters.tools import load_mcp_tools

# from redbox.graph.nodes.sends import _get_mcp_headers
from redbox.models.file import ChunkCreatorType
from redbox.api.format import format_mcp_tool_response


log = logging.getLogger(__name__)


def _get_mcp_headers(sso_access_token: str | None = None) -> dict[str, str]:
if not sso_access_token:
return {}
token = sso_access_token.strip()
if not token:
return {}
if token.lower().startswith("bearer "):
return {"Authorization": token}
return {"Authorization": f"Bearer {token}"}


def wrap_async_tool(tool, tool_name):
    """
    Returns a synchronous function that properly wraps an async MCP tool

    Args:
        tool: The locally-registered tool whose ``metadata`` carries the MCP
            server ``url``, ``creator_type`` and ``sso_access_token`` provider
        tool_name: The name of the tool to invoke on the MCP server

    Returns:
        A function that synchronously executes the async tool
    """

    def wrapper(args):
        # Create a new event loop for this thread (the wrapper runs inside a
        # ThreadPoolExecutor worker, which has no running loop of its own).
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        # Connection details come from the tool's metadata.
        mcp_url = tool.metadata["url"]
        creator_type = tool.metadata["creator_type"]
        # NOTE(review): assumes sso_access_token is a lazy provider exposing
        # .get() — confirm against where the metadata is populated.
        sso_access_token = tool.metadata["sso_access_token"].get()
        headers = _get_mcp_headers(sso_access_token)

        try:
            # Define the async operation
            async def run_tool():
                # tool need to be executed within the connection context manager
                async with streamablehttp_client(mcp_url, headers=headers or None) as (
                    read,
                    write,
                    _,
                ):
                    async with ClientSession(read, write) as session:
                        # Initialize the connection
                        init_result = await session.initialize()
                        server_name = init_result.serverInfo.name
                        server_version = init_result.serverInfo.version

                        log.info(f"Calling tool '{tool_name}' on MCP server {server_name}@{server_version}")

                        # Get tools
                        tools = await load_mcp_tools(session)

                        selected_tool = next((t for t in tools if t.name == tool_name), None)
                        if not selected_tool:
                            raise ValueError(f"tool with name '{tool_name}' not found")

                        # Remove intermediate step argument if it is not required by tool.
                        # BUG FIX: JSON Schema omits the "required" key entirely when no
                        # property is required, so indexing args_schema["required"]
                        # raised KeyError; .get with a default handles that case.
                        required_args = selected_tool.args_schema.get("required", [])
                        if "is_intermediate_step" not in required_args and args.get("is_intermediate_step"):
                            args.pop("is_intermediate_step")
                            log.warning(f"updated args: {args}")

                        log.warning(f"tool found with name '{tool_name}'")
                        log.warning(f"args '{args}'")
                        result = await selected_tool.ainvoke(args)

                        log.warning(f"MCP Tool '{tool_name}' result: {result}")

                        if creator_type == ChunkCreatorType.datahub:
                            log.warning(f"Formatting MCP tool response for creator_type='{creator_type}'")
                            return format_mcp_tool_response(
                                tool_response=result,
                                creator_type=creator_type,
                            )

                        log.warning(f"Returning raw MCP tool response for creator_type='{creator_type}'")
                        return result

            # Run the async function and return its result
            return loop.run_until_complete(run_tool())
        finally:
            # Clean up resources.
            loop.close()
            # BUG FIX: un-register the now-closed loop so later code running on
            # this thread does not pick up a closed event loop from
            # asyncio.get_event_loop().
            asyncio.set_event_loop(None)

    return wrapper
Loading
Loading