23 changes: 19 additions & 4 deletions cascadeflow/agent.py
@@ -735,6 +735,19 @@ def _get_provider(self, model: ModelConfig):
         # Fallback to provider-type lookup (backwards compatibility)
         return self.providers[model.provider]

+    def _provider_supports_tools(self, provider: Any) -> bool:
+        supports_attr = getattr(provider, "supports_tools", None)
+        if isinstance(supports_attr, bool):
+            return supports_attr
+        if callable(supports_attr):
+            return bool(supports_attr())
+        return callable(getattr(provider, "complete_with_tools", None))
+
+    def _get_tool_complete_callable(self, provider: Any):
+        from cascadeflow.providers.base import BaseProvider
+
+        return provider.complete if isinstance(provider, BaseProvider) else provider.complete_with_tools
+
     def _normalize_messages(
         self, query: str, messages: Optional[list[dict[str, Any]]]
     ) -> tuple[str, Optional[list[dict[str, Any]]]]:
@@ -2144,11 +2157,12 @@ async def _execute_direct_with_timing(

         direct_start = time.time()
         transcript: list[dict[str, Any]] = []
-        if tools and hasattr(provider, "complete_with_tools"):
+        if tools and self._provider_supports_tools(provider):
             tool_messages = list(messages or [{"role": "user", "content": query}])
+            tool_complete = self._get_tool_complete_callable(provider)
             response = None
             for step in range(max_steps):
-                response = await provider.complete_with_tools(
+                response = await tool_complete(
                     messages=tool_messages,
                     tools=tools,
                     tool_choice=tool_choice,
@@ -2332,9 +2346,10 @@ async def _stream_direct_with_timing(
                 visual.clear()
                 raise
         else:
-            if tools and hasattr(provider, "complete_with_tools"):
+            if tools and self._provider_supports_tools(provider):
                 tool_messages = messages or [{"role": "user", "content": query}]
-                response = await provider.complete_with_tools(
+                tool_complete = self._get_tool_complete_callable(provider)
+                response = await tool_complete(
                     messages=tool_messages,
                     tools=tools,
                     tool_choice=tool_choice,
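Note on the capability check above: `_provider_supports_tools` is duck-typed on purpose. A standalone sketch of the rule it encodes (the stub classes here are hypothetical, not part of cascadeflow): an explicit `supports_tools` attribute or method wins, and a provider that merely exposes `complete_with_tools` is assumed tool-capable for backwards compatibility.

from typing import Any


def provider_supports_tools(provider: Any) -> bool:
    supports_attr = getattr(provider, "supports_tools", None)
    if isinstance(supports_attr, bool):   # explicit attribute: supports_tools = False
        return supports_attr
    if callable(supports_attr):           # explicit method: def supports_tools(self) -> bool
        return bool(supports_attr())
    # Legacy fallback: exposing complete_with_tools() implies tool support.
    return callable(getattr(provider, "complete_with_tools", None))


class _LegacyStub:  # hypothetical: only the legacy method, no declaration
    async def complete_with_tools(self, **kwargs): ...


class _OptOutStub:  # hypothetical: explicit opt-out despite having the method
    supports_tools = False
    async def complete_with_tools(self, **kwargs): ...


assert provider_supports_tools(_LegacyStub()) is True
assert provider_supports_tools(_OptOutStub()) is False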
4 changes: 2 additions & 2 deletions cascadeflow/core/cascade.py
@@ -1327,7 +1327,7 @@ async def _call_drafter(
             # TOOL PATH: Use complete_with_tools() with messages format
             tool_messages = messages or [{"role": "user", "content": query}]

-            result = await provider.complete_with_tools(
+            result = await provider.complete(
                 messages=tool_messages,
                 tools=tools,
                 tool_choice=tool_choice,
@@ -1374,7 +1374,7 @@ async def _call_verifier(
             # TOOL PATH: Use complete_with_tools() with messages format
             tool_messages = messages or [{"role": "user", "content": query}]

-            result = await provider.complete_with_tools(
+            result = await provider.complete(
                 messages=tool_messages,
                 tools=tools,
                 tool_choice=tool_choice,
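The cascade change is the core of this PR: drafter and verifier tool calls now go through `provider.complete(...)` rather than calling `complete_with_tools` directly, so they pass through whatever retry and circuit-breaker wrapping BaseProvider applies. The diff does not show those internals; the test at the bottom exercises the retry path. A minimal sketch of the idea, assuming a generic retry wrapper:

import asyncio
from typing import Any, Awaitable, Callable


async def with_retries(call: Callable[[], Awaitable[Any]], max_attempts: int = 2) -> Any:
    # Stand-in for BaseProvider's real retry logic, which this diff does not show.
    last_error: Exception | None = None
    for _ in range(max_attempts):
        try:
            return await call()
        except RuntimeError as exc:  # real code matches rate-limit errors, backs off, etc.
            last_error = exc
    raise last_error

A direct `provider.complete_with_tools(...)` call bypasses any such wrapper; routing through `provider.complete(...)` does not.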
47 changes: 40 additions & 7 deletions cascadeflow/providers/base.py
@@ -542,12 +542,17 @@ def _check_tool_support(self) -> bool:
         Check if provider supports tool calling.

         Override to indicate if provider supports tool calling.
-        Default: False (safe default, providers opt-in)
+        Default: auto-detect based on provider implementation.

         Returns:
             True if provider supports tool calling, False otherwise
         """
-        return False
+        # Historically, parts of the library detected tool support via
+        # `hasattr(provider, "complete_with_tools")` and invoked it directly.
+        # To keep that behavior consistent (while routing tool calls through
+        # BaseProvider.complete() to benefit from retry/circuit-breaker), we
+        # default to auto-detecting support when `complete_with_tools` exists.
+        return callable(getattr(self, "complete_with_tools", None))

     # ========================================================================
     # RETRY LOGIC METHODS
@@ -821,7 +826,7 @@ async def _complete_with_tools_impl(
         prompt: Optional[str] = None,
         model: str = "",
         tools: Optional[list[Any]] = None,
-        tool_choice: str = "auto",
+        tool_choice: Any = "auto",
         max_tokens: int = 4096,
         temperature: float = 0.7,
         system_prompt: Optional[str] = None,
@@ -892,9 +897,37 @@ async def _complete_with_tools_impl(
         ...     metadata=metadata
         ... )
         """
-        raise NotImplementedError(
-            f"{self.__class__.__name__} does not support tool calling. "
-            f"Override _complete_with_tools_impl() to add support."
-        )
+        complete_with_tools = getattr(self, "complete_with_tools", None)
+        if not callable(complete_with_tools):
+            raise NotImplementedError(
+                f"{self.__class__.__name__} does not support tool calling. "
+                "Implement complete_with_tools(...) or override _complete_with_tools_impl()."
+            )
+
+        normalized_messages: list[dict[str, Any]]
+        if messages:
+            normalized_messages = list(messages)
+        elif prompt is not None:
+            normalized_messages = [{"role": "user", "content": prompt}]
+        else:
+            raise ValueError("Tool calling requires either `messages` or `prompt`.")
+
+        if system_prompt:
+            has_system = any(
+                isinstance(item, dict) and item.get("role") == "system"
+                for item in normalized_messages
+            )
+            if not has_system:
+                normalized_messages.insert(0, {"role": "system", "content": system_prompt})
+
+        return await complete_with_tools(
+            messages=normalized_messages,
+            tools=tools,
+            model=model,
+            max_tokens=max_tokens,
+            temperature=temperature,
+            tool_choice=tool_choice,
+            **kwargs,
+        )

     def _get_litellm_prefix(self) -> Optional[str]:
@@ -1007,7 +1040,7 @@ async def complete(
         system_prompt: Optional[str] = None,
         messages: Optional[list[dict[str, Any]]] = None,
         tools: Optional[list[Any]] = None,
-        tool_choice: str = "auto",
+        tool_choice: Any = "auto",
         **kwargs,
     ) -> ModelResponse:
         """
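For clarity, the normalization rules the new `_complete_with_tools_impl` applies, restated as a standalone sketch (not the library function itself): explicit `messages` win over `prompt`, a bare `prompt` becomes a single user message, and `system_prompt` is only prepended when no system message is already present.

from typing import Any, Optional


def normalize(messages: Optional[list[dict[str, Any]]],
              prompt: Optional[str],
              system_prompt: Optional[str]) -> list[dict[str, Any]]:
    if messages:
        out = list(messages)              # copy so the caller's list is untouched
    elif prompt is not None:
        out = [{"role": "user", "content": prompt}]
    else:
        raise ValueError("Tool calling requires either `messages` or `prompt`.")
    if system_prompt and not any(
        isinstance(m, dict) and m.get("role") == "system" for m in out
    ):
        out.insert(0, {"role": "system", "content": system_prompt})
    return out


assert normalize(None, "hi", "be terse")[0]["role"] == "system"
assert normalize([{"role": "system", "content": "x"}], None, "y")[0]["content"] == "x"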
10 changes: 9 additions & 1 deletion cascadeflow/providers/vllm.py
@@ -5,6 +5,7 @@
 import time
 from collections.abc import AsyncIterator
 from typing import Any, Optional
+from urllib.parse import urlparse, urlunparse

 import httpx

@@ -183,7 +184,14 @@ def __init__(
"""
super().__init__(api_key=api_key, retry_config=retry_config, http_config=http_config)

self.base_url = base_url or os.getenv("VLLM_BASE_URL", "http://localhost:8000/v1")
raw_base_url = base_url or os.getenv("VLLM_BASE_URL", "http://localhost:8000/v1")
raw_base_url = raw_base_url.rstrip("/")
parsed = urlparse(raw_base_url)
if parsed.path in ("", "/"):
parsed = parsed._replace(path="/v1")
self.base_url = urlunparse(parsed).rstrip("/")
else:
self.base_url = raw_base_url
self.timeout = timeout

headers = {"Content-Type": "application/json"}
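The URL handling above appends `/v1` only when the configured base URL has no path, so the OpenAI-compatible default still works while a custom reverse-proxy path is preserved. A standalone sketch of that behavior:

from urllib.parse import urlparse, urlunparse


def normalize_base_url(raw: str) -> str:
    raw = raw.rstrip("/")
    parsed = urlparse(raw)
    if parsed.path in ("", "/"):
        # Bare host: assume the standard OpenAI-compatible /v1 prefix.
        return urlunparse(parsed._replace(path="/v1")).rstrip("/")
    return raw  # explicit path: leave it alone


assert normalize_base_url("http://localhost:8000") == "http://localhost:8000/v1"
assert normalize_base_url("http://localhost:8000/") == "http://localhost:8000/v1"
assert normalize_base_url("http://host:8000/custom/v2") == "http://host:8000/custom/v2"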
60 changes: 52 additions & 8 deletions cascadeflow/streaming/tools.py
@@ -502,9 +502,17 @@ async def stream(
logger.info(f"Streaming from draft model: {draft_model.name}")

# Check if provider supports tool calling
if hasattr(draft_provider, "complete_with_tools"):
def _provider_supports_tools(p: Any) -> bool:
supports_attr = getattr(p, "supports_tools", None)
if isinstance(supports_attr, bool):
return supports_attr
if callable(supports_attr):
return bool(supports_attr())
return callable(getattr(p, "complete_with_tools", None))

if _provider_supports_tools(draft_provider):
# Use tool-specific method
if hasattr(draft_provider, "stream_with_tools"):
if hasattr(type(draft_provider), "stream_with_tools"):
# Streaming with tools
logger.info("Using stream_with_tools for progressive tool streaming")

@@ -539,9 +547,18 @@ async def stream(
logger.info("Using complete_with_tools (non-streaming)")

# 🔧 FIX: Pass messages instead of model/prompt
response = await draft_provider.complete_with_tools(
from cascadeflow.providers.base import BaseProvider

_draft_tool_complete = (
draft_provider.complete
if isinstance(draft_provider, BaseProvider)
else draft_provider.complete_with_tools
)

response = await _draft_tool_complete(
messages=tool_messages, # ✅ FIXED
tools=tools, # ← Explicit
model=draft_model.name,
max_tokens=max_tokens,
temperature=temperature,
tool_choice=tool_choice, # ← Explicit
@@ -831,9 +848,18 @@ async def stream(
                     )
                     draft_input_tokens += self._estimate_messages_tokens(current_messages)

-                    response = await draft_provider.complete_with_tools(
+                    from cascadeflow.providers.base import BaseProvider
+
+                    _draft_turn_complete = (
+                        draft_provider.complete
+                        if isinstance(draft_provider, BaseProvider)
+                        else draft_provider.complete_with_tools
+                    )
+
+                    response = await _draft_turn_complete(
                         messages=current_messages,
                         tools=tools,
+                        model=draft_model.name,
                         max_tokens=max_tokens,
                         temperature=temperature,
                         tool_choice=tool_choice,
@@ -986,8 +1012,8 @@ async def stream(

             verifier_input_tokens += self._estimate_messages_tokens(tool_messages)

-            if hasattr(verifier_provider, "complete_with_tools"):
-                if hasattr(verifier_provider, "stream_with_tools"):
+            if _provider_supports_tools(verifier_provider):
+                if hasattr(type(verifier_provider), "stream_with_tools"):
                     # Streaming verifier
                     logger.info("Verifier: Using stream_with_tools")

Expand Down Expand Up @@ -1017,9 +1043,18 @@ async def stream(
logger.info("Verifier: Using complete_with_tools (non-streaming)")

# 🔧 FIX: Pass messages instead of model/prompt
response = await verifier_provider.complete_with_tools(
from cascadeflow.providers.base import BaseProvider

_verifier_tool_complete = (
verifier_provider.complete
if isinstance(verifier_provider, BaseProvider)
else verifier_provider.complete_with_tools
)

response = await _verifier_tool_complete(
messages=tool_messages, # ✅ FIXED
tools=tools, # ← Explicit
model=verifier_model.name,
max_tokens=max_tokens,
temperature=temperature,
tool_choice=tool_choice, # ← Explicit
@@ -1106,9 +1141,18 @@ async def stream(
                     )
                     verifier_input_tokens += self._estimate_messages_tokens(current_messages)

-                    response = await verifier_provider.complete_with_tools(
+                    from cascadeflow.providers.base import BaseProvider
+
+                    _verifier_turn_complete = (
+                        verifier_provider.complete
+                        if isinstance(verifier_provider, BaseProvider)
+                        else verifier_provider.complete_with_tools
+                    )
+
+                    response = await _verifier_turn_complete(
                         messages=current_messages,
                         tools=tools,
+                        model=verifier_model.name,
                         max_tokens=max_tokens,
                         temperature=temperature,
                         tool_choice=tool_choice,
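One subtlety in the streaming changes: the check moved from `hasattr(provider, "stream_with_tools")` to `hasattr(type(provider), "stream_with_tools")`. Checking the class ignores instance-level attributes (for example, ones attached by mocks or ad-hoc monkey-patching), so only providers whose class actually defines the method take the streaming path. The stub below is hypothetical:

class _NoStreaming:
    pass


p = _NoStreaming()
p.stream_with_tools = lambda: None          # instance attribute only

assert hasattr(p, "stream_with_tools") is True
assert hasattr(type(p), "stream_with_tools") is False  # class-level check filters it out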
2 changes: 1 addition & 1 deletion tests/test_ollama.py
@@ -56,7 +56,7 @@ def test_init_custom_url(self):

     def test_init_from_env(self):
         """Test initialization from OLLAMA_HOST env var."""
-        with patch.dict("os.environ", {"OLLAMA_HOST": "http://remote:11434"}):
+        with patch.dict("os.environ", {"OLLAMA_HOST": "http://remote:11434"}, clear=True):
             provider = OllamaProvider()
             assert provider.base_url == "http://remote:11434"

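The `clear=True` additions in both env-var tests make the patched environment exactly the mapping the test specifies, so unrelated variables exported in a developer's shell cannot change provider behavior. A minimal illustration (the variable name is hypothetical):

import os
from unittest.mock import patch

os.environ["SOME_OTHER_VAR"] = "from-the-shell"  # hypothetical leftover shell state
with patch.dict(os.environ, {"OLLAMA_HOST": "http://remote:11434"}, clear=True):
    assert os.environ["OLLAMA_HOST"] == "http://remote:11434"
    assert "SOME_OTHER_VAR" not in os.environ  # clear=True removed everything else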
100 changes: 100 additions & 0 deletions tests/test_provider_tool_retry.py
@@ -0,0 +1,100 @@
+import asyncio
+
+from cascadeflow.providers.base import BaseProvider, ModelResponse, RetryConfig
+
+
+class _FakeToolProvider(BaseProvider):
+    def __init__(self, *, retry_config: RetryConfig) -> None:
+        super().__init__(api_key="test", retry_config=retry_config, enable_circuit_breaker=False)
+        self.complete_with_tools_calls = 0
+
+    async def _complete_impl(
+        self,
+        prompt: str,
+        model: str,
+        max_tokens: int = 4096,
+        temperature: float = 0.7,
+        system_prompt: str | None = None,
+        **kwargs,
+    ) -> ModelResponse:
+        return ModelResponse(
+            content="ok",
+            model=model,
+            provider="fake",
+            cost=0.0,
+            tokens_used=0,
+            confidence=1.0,
+            latency_ms=0.0,
+            tool_calls=None,
+            metadata={},
+        )
+
+    async def _stream_impl(
+        self,
+        prompt: str,
+        model: str,
+        max_tokens: int = 4096,
+        temperature: float = 0.7,
+        system_prompt: str | None = None,
+        **kwargs,
+    ):
+        if False:  # pragma: no cover
+            yield ""
+
+    def estimate_cost(self, tokens: int, model: str) -> float:
+        return 0.0
+
+    async def complete_with_tools(
+        self,
+        messages: list[dict[str, str]],
+        tools=None,
+        model: str = "x",
+        max_tokens: int = 4096,
+        temperature: float = 0.7,
+        tool_choice=None,
+        **kwargs,
+    ) -> ModelResponse:
+        self.complete_with_tools_calls += 1
+        if self.complete_with_tools_calls == 1:
+            raise RuntimeError("429 Too Many Requests")
+        return ModelResponse(
+            content="tool-ok",
+            model=model,
+            provider="fake",
+            cost=0.0,
+            tokens_used=0,
+            confidence=1.0,
+            latency_ms=0.0,
+            tool_calls=[],
+            metadata={},
+        )
+
+
+def test_tool_calls_use_retry_wrapper() -> None:
+    async def _run() -> None:
+        provider = _FakeToolProvider(
+            retry_config=RetryConfig(
+                max_attempts=2,
+                initial_delay=0,
+                max_delay=0,
+                jitter=False,
+                rate_limit_backoff=0,
+            )
+        )
+
+        response = await provider.complete(
+            messages=[{"role": "user", "content": "hi"}],
+            tools=[
+                {
+                    "type": "function",
+                    "function": {"name": "noop", "parameters": {"type": "object"}},
+                }
+            ],
+            model="x",
+        )
+
+        assert response.content == "tool-ok"
+        assert provider.complete_with_tools_calls == 2
+        assert provider.retry_metrics.total_attempts == 2
+
+    asyncio.run(_run())
2 changes: 1 addition & 1 deletion tests/test_vllm.py
@@ -42,7 +42,7 @@ def test_init_custom_url(self):

     def test_init_from_env(self):
         """Test initialization from environment variable."""
-        with patch.dict(os.environ, {"VLLM_BASE_URL": "http://env:8000/v1"}):
+        with patch.dict(os.environ, {"VLLM_BASE_URL": "http://env:8000/v1"}, clear=True):
             provider = VLLMProvider()
             assert provider.base_url == "http://env:8000/v1"
