diff --git a/cascadeflow/agent.py b/cascadeflow/agent.py
index c8faab44..7b60d7c0 100644
--- a/cascadeflow/agent.py
+++ b/cascadeflow/agent.py
@@ -735,6 +735,19 @@ def _get_provider(self, model: ModelConfig):
         # Fallback to provider-type lookup (backwards compatibility)
         return self.providers[model.provider]
 
+    def _provider_supports_tools(self, provider: Any) -> bool:
+        supports_attr = getattr(provider, "supports_tools", None)
+        if isinstance(supports_attr, bool):
+            return supports_attr
+        if callable(supports_attr):
+            return bool(supports_attr())
+        return callable(getattr(provider, "complete_with_tools", None))
+
+    def _get_tool_complete_callable(self, provider: Any):
+        from cascadeflow.providers.base import BaseProvider
+
+        return provider.complete if isinstance(provider, BaseProvider) else provider.complete_with_tools
+
     def _normalize_messages(
         self, query: str, messages: Optional[list[dict[str, Any]]]
     ) -> tuple[str, Optional[list[dict[str, Any]]]]:
@@ -2144,11 +2157,12 @@ async def _execute_direct_with_timing(
         direct_start = time.time()
         transcript: list[dict[str, Any]] = []
 
-        if tools and hasattr(provider, "complete_with_tools"):
+        if tools and self._provider_supports_tools(provider):
             tool_messages = list(messages or [{"role": "user", "content": query}])
+            tool_complete = self._get_tool_complete_callable(provider)
             response = None
             for step in range(max_steps):
-                response = await provider.complete_with_tools(
+                response = await tool_complete(
                     messages=tool_messages,
                     tools=tools,
                     tool_choice=tool_choice,
@@ -2332,9 +2346,10 @@ async def _stream_direct_with_timing(
                     visual.clear()
                     raise
         else:
-            if tools and hasattr(provider, "complete_with_tools"):
+            if tools and self._provider_supports_tools(provider):
                 tool_messages = messages or [{"role": "user", "content": query}]
-                response = await provider.complete_with_tools(
+                tool_complete = self._get_tool_complete_callable(provider)
+                response = await tool_complete(
                     messages=tool_messages,
                     tools=tools,
                     tool_choice=tool_choice,
diff --git a/cascadeflow/core/cascade.py b/cascadeflow/core/cascade.py
index 2efcc3af..4144c2de 100644
--- a/cascadeflow/core/cascade.py
+++ b/cascadeflow/core/cascade.py
@@ -1327,7 +1327,7 @@ async def _call_drafter(
-            # TOOL PATH: Use complete_with_tools() with messages format
+            # TOOL PATH: route tool calls through provider.complete() (retry-aware)
             tool_messages = messages or [{"role": "user", "content": query}]
 
-            result = await provider.complete_with_tools(
+            result = await provider.complete(
                 messages=tool_messages,
                 tools=tools,
                 tool_choice=tool_choice,
@@ -1374,7 +1374,7 @@ async def _call_verifier(
-            # TOOL PATH: Use complete_with_tools() with messages format
+            # TOOL PATH: route tool calls through provider.complete() (retry-aware)
             tool_messages = messages or [{"role": "user", "content": query}]
 
-            result = await provider.complete_with_tools(
+            result = await provider.complete(
                 messages=tool_messages,
                 tools=tools,
                 tool_choice=tool_choice,
diff --git a/cascadeflow/providers/base.py b/cascadeflow/providers/base.py
index 55afa024..f133db14 100644
--- a/cascadeflow/providers/base.py
+++ b/cascadeflow/providers/base.py
@@ -542,12 +542,17 @@ def _check_tool_support(self) -> bool:
         Check if provider supports tool calling.
 
         Override to indicate if provider supports tool calling.
-        Default: False (safe default, providers opt-in)
+        Default: auto-detect based on provider implementation.
 
         Returns:
             True if provider supports tool calling, False otherwise
         """
+        # Historically, parts of the library detected tool support via
+        # `hasattr(provider, "complete_with_tools")` and invoked it directly.
+        # To keep that behavior consistent (while routing tool calls through
+        # BaseProvider.complete() to benefit from retry/circuit-breaker), we
+        # default to auto-detecting support when `complete_with_tools` exists.
-        return False
+        return callable(getattr(self, "complete_with_tools", None))
 
     # ========================================================================
     # RETRY LOGIC METHODS
     # ========================================================================
@@ -821,7 +826,7 @@ async def _complete_with_tools_impl(
         prompt: Optional[str] = None,
         model: str = "",
         tools: Optional[list[Any]] = None,
-        tool_choice: str = "auto",
+        tool_choice: Any = "auto",
         max_tokens: int = 4096,
         temperature: float = 0.7,
         system_prompt: Optional[str] = None,
@@ -892,9 +897,37 @@ async def _complete_with_tools_impl(
             ...     metadata=metadata
             ... )
         """
-        raise NotImplementedError(
-            f"{self.__class__.__name__} does not support tool calling. "
-            f"Override _complete_with_tools_impl() to add support."
+        complete_with_tools = getattr(self, "complete_with_tools", None)
+        if not callable(complete_with_tools):
+            raise NotImplementedError(
+                f"{self.__class__.__name__} does not support tool calling. "
+                "Implement complete_with_tools(...) or override _complete_with_tools_impl()."
+            )
+
+        normalized_messages: list[dict[str, Any]]
+        if messages:
+            normalized_messages = list(messages)
+        elif prompt is not None:
+            normalized_messages = [{"role": "user", "content": prompt}]
+        else:
+            raise ValueError("Tool calling requires either `messages` or `prompt`.")
+
+        if system_prompt:
+            has_system = any(
+                isinstance(item, dict) and item.get("role") == "system"
+                for item in normalized_messages
+            )
+            if not has_system:
+                normalized_messages.insert(0, {"role": "system", "content": system_prompt})
+
+        return await complete_with_tools(
+            messages=normalized_messages,
+            tools=tools,
+            model=model,
+            max_tokens=max_tokens,
+            temperature=temperature,
+            tool_choice=tool_choice,
+            **kwargs,
         )
 
     def _get_litellm_prefix(self) -> Optional[str]:
@@ -1007,7 +1040,7 @@ async def complete(
         system_prompt: Optional[str] = None,
         messages: Optional[list[dict[str, Any]]] = None,
         tools: Optional[list[Any]] = None,
-        tool_choice: str = "auto",
+        tool_choice: Any = "auto",
         **kwargs,
     ) -> ModelResponse:
         """
diff --git a/cascadeflow/providers/vllm.py b/cascadeflow/providers/vllm.py
index ab79d28a..08bad095 100644
--- a/cascadeflow/providers/vllm.py
+++ b/cascadeflow/providers/vllm.py
@@ -5,6 +5,7 @@
 import time
 from collections.abc import AsyncIterator
 from typing import Any, Optional
+from urllib.parse import urlparse, urlunparse
 
 import httpx
 
@@ -183,7 +184,14 @@ def __init__(
         """
         super().__init__(api_key=api_key, retry_config=retry_config, http_config=http_config)
 
-        self.base_url = base_url or os.getenv("VLLM_BASE_URL", "http://localhost:8000/v1")
+        raw_base_url = base_url or os.getenv("VLLM_BASE_URL", "http://localhost:8000/v1")
+        raw_base_url = raw_base_url.rstrip("/")
+        parsed = urlparse(raw_base_url)
+        if parsed.path in ("", "/"):
+            parsed = parsed._replace(path="/v1")
+            self.base_url = urlunparse(parsed).rstrip("/")
+        else:
+            self.base_url = raw_base_url
         self.timeout = timeout
 
         headers = {"Content-Type": "application/json"}
diff --git a/cascadeflow/streaming/tools.py b/cascadeflow/streaming/tools.py
index f185b4f2..a04b6a5b 100644
--- a/cascadeflow/streaming/tools.py
+++ b/cascadeflow/streaming/tools.py
@@ -502,9 +502,17 @@ async def stream(
         logger.info(f"Streaming from draft model: {draft_model.name}")
 
         # Check if provider supports tool calling
-        if hasattr(draft_provider, "complete_with_tools"):
+        def _provider_supports_tools(p: Any) -> bool:
+            supports_attr = getattr(p, "supports_tools", None)
+            if isinstance(supports_attr, bool):
+                return supports_attr
+            if callable(supports_attr):
+                return bool(supports_attr())
+            return callable(getattr(p, "complete_with_tools", None))
+
+        if _provider_supports_tools(draft_provider):
             # Use tool-specific method
-            if hasattr(draft_provider, "stream_with_tools"):
+            if hasattr(type(draft_provider), "stream_with_tools"):
                 # Streaming with tools
                 logger.info("Using stream_with_tools for progressive tool streaming")
@@ -539,9 +547,18 @@ async def stream(
                 logger.info("Using complete_with_tools (non-streaming)")
 
                 # 🔧 FIX: Pass messages instead of model/prompt
-                response = await draft_provider.complete_with_tools(
+                from cascadeflow.providers.base import BaseProvider
+
+                _draft_tool_complete = (
+                    draft_provider.complete
+                    if isinstance(draft_provider, BaseProvider)
+                    else draft_provider.complete_with_tools
+                )
+
+                response = await _draft_tool_complete(
                     messages=tool_messages,  # ✅ FIXED
                     tools=tools,  # ← Explicit
+                    model=draft_model.name,
                     max_tokens=max_tokens,
                     temperature=temperature,
                     tool_choice=tool_choice,  # ← Explicit
@@ -831,9 +848,18 @@ async def stream(
                     )
                     draft_input_tokens += self._estimate_messages_tokens(current_messages)
 
-                    response = await draft_provider.complete_with_tools(
+                    from cascadeflow.providers.base import BaseProvider
+
+                    _draft_turn_complete = (
+                        draft_provider.complete
+                        if isinstance(draft_provider, BaseProvider)
+                        else draft_provider.complete_with_tools
+                    )
+
+                    response = await _draft_turn_complete(
                         messages=current_messages,
                         tools=tools,
+                        model=draft_model.name,
                         max_tokens=max_tokens,
                         temperature=temperature,
                         tool_choice=tool_choice,
@@ -986,8 +1012,8 @@ async def stream(
 
             verifier_input_tokens += self._estimate_messages_tokens(tool_messages)
 
-            if hasattr(verifier_provider, "complete_with_tools"):
-                if hasattr(verifier_provider, "stream_with_tools"):
+            if _provider_supports_tools(verifier_provider):
+                if hasattr(type(verifier_provider), "stream_with_tools"):
                     # Streaming verifier
                     logger.info("Verifier: Using stream_with_tools")
@@ -1017,9 +1043,18 @@ async def stream(
                 logger.info("Verifier: Using complete_with_tools (non-streaming)")
 
                 # 🔧 FIX: Pass messages instead of model/prompt
-                response = await verifier_provider.complete_with_tools(
+                from cascadeflow.providers.base import BaseProvider
+
+                _verifier_tool_complete = (
+                    verifier_provider.complete
+                    if isinstance(verifier_provider, BaseProvider)
+                    else verifier_provider.complete_with_tools
+                )
+
+                response = await _verifier_tool_complete(
                     messages=tool_messages,  # ✅ FIXED
                     tools=tools,  # ← Explicit
+                    model=verifier_model.name,
                     max_tokens=max_tokens,
                     temperature=temperature,
                     tool_choice=tool_choice,  # ← Explicit
@@ -1106,9 +1141,18 @@ async def stream(
                     )
                     verifier_input_tokens += self._estimate_messages_tokens(current_messages)
 
-                    response = await verifier_provider.complete_with_tools(
+                    from cascadeflow.providers.base import BaseProvider
+
+                    _verifier_turn_complete = (
+                        verifier_provider.complete
+                        if isinstance(verifier_provider, BaseProvider)
+                        else verifier_provider.complete_with_tools
+                    )
+
+                    response = await _verifier_turn_complete(
                         messages=current_messages,
                         tools=tools,
+                        model=verifier_model.name,
                         max_tokens=max_tokens,
                         temperature=temperature,
                         tool_choice=tool_choice,
diff --git a/tests/test_ollama.py b/tests/test_ollama.py
index d71e2dfb..467e97dd 100644
--- a/tests/test_ollama.py
+++ b/tests/test_ollama.py
@@ -56,7 +56,7 @@ def test_init_custom_url(self):
 
     def test_init_from_env(self):
         """Test initialization from OLLAMA_HOST env var."""
-        with patch.dict("os.environ", {"OLLAMA_HOST": "http://remote:11434"}):
+        with patch.dict("os.environ", {"OLLAMA_HOST": "http://remote:11434"}, clear=True):
             provider = OllamaProvider()
             assert provider.base_url == "http://remote:11434"
diff --git a/tests/test_provider_tool_retry.py b/tests/test_provider_tool_retry.py
new file mode 100644
index 00000000..23b1107d
--- /dev/null
+++ b/tests/test_provider_tool_retry.py
@@ -0,0 +1,100 @@
+import asyncio
+
+from cascadeflow.providers.base import BaseProvider, ModelResponse, RetryConfig
+
+
+class _FakeToolProvider(BaseProvider):
+    def __init__(self, *, retry_config: RetryConfig) -> None:
+        super().__init__(api_key="test", retry_config=retry_config, enable_circuit_breaker=False)
+        self.complete_with_tools_calls = 0
+
+    async def _complete_impl(
+        self,
+        prompt: str,
+        model: str,
+        max_tokens: int = 4096,
+        temperature: float = 0.7,
+        system_prompt: str | None = None,
+        **kwargs,
+    ) -> ModelResponse:
+        return ModelResponse(
+            content="ok",
+            model=model,
+            provider="fake",
+            cost=0.0,
+            tokens_used=0,
+            confidence=1.0,
+            latency_ms=0.0,
+            tool_calls=None,
+            metadata={},
+        )
+
+    async def _stream_impl(
+        self,
+        prompt: str,
+        model: str,
+        max_tokens: int = 4096,
+        temperature: float = 0.7,
+        system_prompt: str | None = None,
+        **kwargs,
+    ):
+        if False:  # pragma: no cover
+            yield ""
+
+    def estimate_cost(self, tokens: int, model: str) -> float:
+        return 0.0
+
+    async def complete_with_tools(
+        self,
+        messages: list[dict[str, str]],
+        tools=None,
+        model: str = "x",
+        max_tokens: int = 4096,
+        temperature: float = 0.7,
+        tool_choice=None,
+        **kwargs,
+    ) -> ModelResponse:
+        self.complete_with_tools_calls += 1
+        if self.complete_with_tools_calls == 1:
+            raise RuntimeError("429 Too Many Requests")
+        return ModelResponse(
+            content="tool-ok",
+            model=model,
+            provider="fake",
+            cost=0.0,
+            tokens_used=0,
+            confidence=1.0,
+            latency_ms=0.0,
+            tool_calls=[],
+            metadata={},
+        )
+
+
+def test_tool_calls_use_retry_wrapper() -> None:
+    async def _run() -> None:
+        provider = _FakeToolProvider(
+            retry_config=RetryConfig(
+                max_attempts=2,
+                initial_delay=0,
+                max_delay=0,
+                jitter=False,
+                rate_limit_backoff=0,
+            )
+        )
+
+        response = await provider.complete(
+            messages=[{"role": "user", "content": "hi"}],
+            tools=[
+                {
+                    "type": "function",
+                    "function": {"name": "noop", "parameters": {"type": "object"}},
+                }
+            ],
+            model="x",
+        )
+
+        assert response.content == "tool-ok"
+        assert provider.complete_with_tools_calls == 2
+        assert provider.retry_metrics.total_attempts == 2
+
+    asyncio.run(_run())
diff --git a/tests/test_vllm.py b/tests/test_vllm.py
index 32fad8c6..c1b968bf 100644
--- a/tests/test_vllm.py
+++ b/tests/test_vllm.py
@@ -42,7 +42,7 @@ def test_init_custom_url(self):
 
     def test_init_from_env(self):
         """Test initialization from environment variable."""
-        with patch.dict(os.environ, {"VLLM_BASE_URL": "http://env:8000/v1"}):
+        with patch.dict(os.environ, {"VLLM_BASE_URL": "http://env:8000/v1"}, clear=True):
             provider = VLLMProvider()
             assert provider.base_url == "http://env:8000/v1"