diff --git a/packages/prime-sandboxes/src/prime_sandboxes/core/client.py b/packages/prime-sandboxes/src/prime_sandboxes/core/client.py index 79ce88401..cdf6e84e3 100644 --- a/packages/prime-sandboxes/src/prime_sandboxes/core/client.py +++ b/packages/prime-sandboxes/src/prime_sandboxes/core/client.py @@ -102,6 +102,15 @@ def _check_auth_required(self) -> None: "No API key configured. Set PRIME_API_KEY environment variable.", ) + def _base_url_for_endpoint(self, endpoint: str) -> str: + normalized_endpoint = endpoint if endpoint.startswith("/") else f"/{endpoint}" + is_sandbox_endpoint = normalized_endpoint == "/sandbox" or normalized_endpoint.startswith( + "/sandbox/" + ) + if is_sandbox_endpoint and self.config.sandbox_base_url: + return self.config.sandbox_base_url + return self.base_url + @retry( retry=retry_if_exception(_is_idempotent_request_retryable_error), stop=stop_after_attempt(3), @@ -169,12 +178,13 @@ def request( """Make a request to the API""" self._check_auth_required() + base_url = self._base_url_for_endpoint(endpoint) if not endpoint.startswith("/"): endpoint = f"/api/v1/{endpoint}" else: endpoint = f"/api/v1{endpoint}" - url = f"{self.base_url}{endpoint}" + url = f"{base_url}{endpoint}" try: method_upper = method.upper() @@ -263,6 +273,15 @@ def _check_auth_required(self) -> None: "No API key configured. Set PRIME_API_KEY environment variable.", ) + def _base_url_for_endpoint(self, endpoint: str) -> str: + normalized_endpoint = endpoint if endpoint.startswith("/") else f"/{endpoint}" + is_sandbox_endpoint = normalized_endpoint == "/sandbox" or normalized_endpoint.startswith( + "/sandbox/" + ) + if is_sandbox_endpoint and self.config.sandbox_base_url: + return self.config.sandbox_base_url + return self.base_url + @retry( retry=retry_if_exception(_is_idempotent_request_retryable_error), stop=stop_after_attempt(3), @@ -330,12 +349,13 @@ async def request( """Make an async request to the API""" self._check_auth_required() + base_url = self._base_url_for_endpoint(endpoint) if not endpoint.startswith("/"): endpoint = f"/api/v1/{endpoint}" else: endpoint = f"/api/v1{endpoint}" - url = f"{self.base_url}{endpoint}" + url = f"{base_url}{endpoint}" try: method_upper = method.upper() diff --git a/packages/prime-sandboxes/src/prime_sandboxes/core/config.py b/packages/prime-sandboxes/src/prime_sandboxes/core/config.py index deb59b3f8..77bedfff5 100644 --- a/packages/prime-sandboxes/src/prime_sandboxes/core/config.py +++ b/packages/prime-sandboxes/src/prime_sandboxes/core/config.py @@ -63,3 +63,12 @@ def base_url(self) -> str: if env_val: return self._strip_api_v1(env_val) return self._strip_api_v1(self.config.get("base_url", self.DEFAULT_BASE_URL)) + + @property + def sandbox_base_url(self) -> Optional[str]: + """Get sandbox API base URL with precedence: env > file > None.""" + env_val = os.getenv("PRIME_SANDBOX_BASE_URL") or os.getenv("PRIME_SANDBOX_INGRESS_URL") + value = env_val or self.config.get("sandbox_base_url") + if not value: + return None + return self._strip_api_v1(str(value)) diff --git a/packages/prime-sandboxes/tests/test_sandbox_base_url.py b/packages/prime-sandboxes/tests/test_sandbox_base_url.py new file mode 100644 index 000000000..6aa5a5029 --- /dev/null +++ b/packages/prime-sandboxes/tests/test_sandbox_base_url.py @@ -0,0 +1,117 @@ +import json +from pathlib import Path + +import httpx +import pytest + +from prime_sandboxes.core.client import APIClient, AsyncAPIClient + + +class RecordingTransport(httpx.BaseTransport): + def __init__(self) -> None: + self.requests: list[httpx.Request] = [] + + def handle_request(self, request: httpx.Request) -> httpx.Response: + self.requests.append(request) + return httpx.Response(200, request=request, json={"ok": True}) + + +class AsyncRecordingTransport(httpx.AsyncBaseTransport): + def __init__(self) -> None: + self.requests: list[httpx.Request] = [] + + async def handle_async_request(self, request: httpx.Request) -> httpx.Response: + self.requests.append(request) + return httpx.Response(200, request=request, json={"ok": True}) + + +@pytest.fixture +def temp_home(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path: + monkeypatch.setenv("HOME", str(tmp_path)) + for name in ( + "PRIME_API_KEY", + "PRIME_API_BASE_URL", + "PRIME_BASE_URL", + "PRIME_SANDBOX_BASE_URL", + "PRIME_SANDBOX_INGRESS_URL", + ): + monkeypatch.delenv(name, raising=False) + return tmp_path + + +def write_config(home: Path, config: dict[str, str]) -> None: + config_dir = home / ".prime" + config_dir.mkdir() + (config_dir / "config.json").write_text(json.dumps(config)) + + +def test_sync_client_routes_only_sandbox_endpoints_to_sandbox_base_url( + temp_home: Path, +) -> None: + write_config( + temp_home, + { + "api_key": "test-key", + "base_url": "https://api.example", + "sandbox_base_url": "https://sandbox.example", + }, + ) + transport = RecordingTransport() + client = APIClient() + client.client = httpx.Client(transport=transport) + + assert client.request("GET", "/sandbox") == {"ok": True} + assert client.request("GET", "sandbox/sbx-1") == {"ok": True} + assert client.request("GET", "sandboxed") == {"ok": True} + assert client.request("GET", "/template/registry-credentials") == {"ok": True} + + assert [str(request.url) for request in transport.requests] == [ + "https://sandbox.example/api/v1/sandbox", + "https://sandbox.example/api/v1/sandbox/sbx-1", + "https://api.example/api/v1/sandboxed", + "https://api.example/api/v1/template/registry-credentials", + ] + + +def test_sync_client_sandbox_base_url_env_var_overrides_saved_config( + temp_home: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + write_config( + temp_home, + { + "api_key": "test-key", + "base_url": "https://api.example", + "sandbox_base_url": "https://saved-sandbox.example", + }, + ) + monkeypatch.setenv("PRIME_SANDBOX_BASE_URL", "https://env-sandbox.example/api/v1") + transport = RecordingTransport() + client = APIClient() + client.client = httpx.Client(transport=transport) + + assert client.request("GET", "/sandbox/sbx-1") == {"ok": True} + + assert str(transport.requests[0].url) == "https://env-sandbox.example/api/v1/sandbox/sbx-1" + + +@pytest.mark.asyncio +async def test_async_client_routes_sandbox_endpoints_to_sandbox_base_url( + temp_home: Path, +) -> None: + write_config( + temp_home, + { + "api_key": "test-key", + "base_url": "https://api.example", + "sandbox_base_url": "https://sandbox.example/api/v1", + }, + ) + transport = AsyncRecordingTransport() + client = AsyncAPIClient() + client.client = httpx.AsyncClient(transport=transport) + + assert await client.request("GET", "/sandbox/sbx-1/auth") == {"ok": True} + + assert str(transport.requests[0].url) == "https://sandbox.example/api/v1/sandbox/sbx-1/auth" + await client.client.aclose() diff --git a/packages/prime/src/prime_cli/commands/config.py b/packages/prime/src/prime_cli/commands/config.py index bcdd61280..12d56992e 100644 --- a/packages/prime/src/prime_cli/commands/config.py +++ b/packages/prime/src/prime_cli/commands/config.py @@ -83,6 +83,12 @@ def _env_set(*names: str) -> bool: base_label += " (from env var)" table.add_row("Base URL", base_label) + # Show sandbox base URL + sandbox_base_label = settings.get("sandbox_base_url") or "Not set" + if _env_set("PRIME_SANDBOX_BASE_URL", "PRIME_SANDBOX_INGRESS_URL"): + sandbox_base_label += " (from env var)" + table.add_row("Sandbox Base URL", sandbox_base_label) + # Show frontend URL front_label = settings["frontend_url"] if _env_set("PRIME_FRONTEND_URL"): @@ -227,6 +233,27 @@ def set_base_url( console.print(f"[green]Base URL set to: {url}[/green]") +@app.command() +def set_sandbox_base_url( + url: Optional[str] = typer.Argument( + None, + help="Base URL for the sandbox API. If not provided, you'll be prompted.", + ), +) -> None: + """Set the sandbox API base URL (prompts if not provided)""" + if not url: + config = Config() + url = typer.prompt( + "Enter the base URL for the sandbox API", + default=config.sandbox_base_url or "", + ) + + config = Config() + config.set_sandbox_base_url(url or None) + config.update_current_environment_file() + console.print(f"[green]Sandbox Base URL set to: {url or 'Not set'}[/green]") + + @app.command() def set_frontend_url( url: Optional[str] = typer.Argument( @@ -382,6 +409,7 @@ def reset( config.set_api_key("") config.set_team(None) config.set_base_url(Config.DEFAULT_BASE_URL) + config.set_sandbox_base_url(None) config.set_frontend_url(Config.DEFAULT_FRONTEND_URL) config.set_ssh_key_path(Config.DEFAULT_SSH_KEY_PATH) config.set_current_environment("production") diff --git a/packages/prime/src/prime_cli/core/config.py b/packages/prime/src/prime_cli/core/config.py index b3ef6e80e..1b7df94c3 100644 --- a/packages/prime/src/prime_cli/core/config.py +++ b/packages/prime/src/prime_cli/core/config.py @@ -14,6 +14,7 @@ class ConfigModel(BaseModel): team_role: str | None = None user_id: str | None = None base_url: str = "https://api.primeintellect.ai" + sandbox_base_url: str | None = None frontend_url: str = "https://app.primeintellect.ai" inference_url: str = "https://api.pinference.ai/api/v1" ssh_key_path: str = str(Path.home() / ".ssh" / "id_rsa") @@ -57,6 +58,7 @@ def _ensure_config_dir(self) -> None: team_id=None, user_id=None, base_url=self.DEFAULT_BASE_URL, + sandbox_base_url=None, frontend_url=self.DEFAULT_FRONTEND_URL, inference_url=self.DEFAULT_INFERENCE_URL, ssh_key_path=self.DEFAULT_SSH_KEY_PATH, @@ -149,6 +151,20 @@ def set_base_url(self, value: str) -> None: self.config["base_url"] = value self._save_config(self.config) + @property + def sandbox_base_url(self) -> Optional[str]: + """Get sandbox API base URL with precedence: env > file > None.""" + env_val = os.getenv("PRIME_SANDBOX_BASE_URL") or os.getenv("PRIME_SANDBOX_INGRESS_URL") + value = env_val or self.config.get("sandbox_base_url") + if not value: + return None + return self._strip_api_v1(str(value)) + + def set_sandbox_base_url(self, value: str | None) -> None: + """Set sandbox API base URL in config file.""" + self.config["sandbox_base_url"] = self._strip_api_v1(value) if value else None + self._save_config(self.config) + @property def frontend_url(self) -> str: """Get frontend URL with precedence: env > file > default.""" @@ -234,6 +250,7 @@ def view(self) -> dict: "team_role": self.team_role, "user_id": self.user_id, "base_url": self.base_url, + "sandbox_base_url": self.sandbox_base_url, "frontend_url": self.frontend_url, "inference_url": self.inference_url, "ssh_key_path": self.ssh_key_path, @@ -255,6 +272,7 @@ def save_environment(self, name: str) -> None: "team_role": None if self.team_id_from_env else self.team_role, "user_id": self.user_id, "base_url": self.base_url, + "sandbox_base_url": self.sandbox_base_url, "frontend_url": self.frontend_url, "inference_url": self.inference_url, } @@ -292,12 +310,14 @@ def load_environment(self, name: str, persist: bool = True) -> bool: # Built-in production environment if persist: self.set_base_url(self.DEFAULT_BASE_URL) + self.set_sandbox_base_url(None) self.set_frontend_url(self.DEFAULT_FRONTEND_URL) self.set_inference_url(self.DEFAULT_INFERENCE_URL) self.set_team(None) # Production defaults to personal account self.set_current_environment("production") else: self.config["base_url"] = self.DEFAULT_BASE_URL + self.config["sandbox_base_url"] = None self.config["frontend_url"] = self.DEFAULT_FRONTEND_URL self.config["inference_url"] = self.DEFAULT_INFERENCE_URL self.config["team_id"] = None @@ -327,6 +347,7 @@ def load_environment(self, name: str, persist: bool = True) -> bool: # Set user_id from environment self.set_user_id(env_config.get("user_id", None)) self.set_base_url(env_config.get("base_url", self.DEFAULT_BASE_URL)) + self.set_sandbox_base_url(env_config.get("sandbox_base_url", None)) self.set_frontend_url(env_config.get("frontend_url", self.DEFAULT_FRONTEND_URL)) self.set_inference_url( env_config.get("inference_url", self.DEFAULT_INFERENCE_URL) @@ -343,6 +364,10 @@ def load_environment(self, name: str, persist: bool = True) -> bool: # Normalize URLs the same way set_* methods do base_url = env_config.get("base_url", self.DEFAULT_BASE_URL) self.config["base_url"] = self._strip_api_v1(base_url) + sandbox_base_url = env_config.get("sandbox_base_url", None) + self.config["sandbox_base_url"] = ( + self._strip_api_v1(sandbox_base_url) if sandbox_base_url else None + ) frontend_url = env_config.get("frontend_url", self.DEFAULT_FRONTEND_URL) self.config["frontend_url"] = frontend_url.rstrip("/") inference_url = env_config.get("inference_url", self.DEFAULT_INFERENCE_URL) @@ -369,6 +394,7 @@ def update_current_environment_file(self) -> None: "team_role": None if self.team_id_from_env else self.team_role, "user_id": self.user_id, "base_url": self.base_url, + "sandbox_base_url": self.sandbox_base_url, "frontend_url": self.frontend_url, "inference_url": self.inference_url, } diff --git a/packages/prime/tests/test_config_sandbox_base_url.py b/packages/prime/tests/test_config_sandbox_base_url.py new file mode 100644 index 000000000..c985afcb0 --- /dev/null +++ b/packages/prime/tests/test_config_sandbox_base_url.py @@ -0,0 +1,111 @@ +import json +from pathlib import Path +from typing import Any + +import pytest +from prime_cli.core import Config +from prime_cli.main import app +from typer.testing import CliRunner + +runner = CliRunner() + +TEST_ENV = { + "COLUMNS": "200", + "LINES": "50", + "PRIME_DISABLE_VERSION_CHECK": "1", +} + + +@pytest.fixture +def temp_home(tmp_path: Any, monkeypatch: pytest.MonkeyPatch) -> Path: + monkeypatch.setenv("HOME", str(tmp_path)) + monkeypatch.delenv("PRIME_SANDBOX_BASE_URL", raising=False) + monkeypatch.delenv("PRIME_SANDBOX_INGRESS_URL", raising=False) + monkeypatch.setattr("prime_cli.main.check_for_update", lambda: (False, None)) + return tmp_path + + +def test_sandbox_base_url_is_saved_with_environment(temp_home: Path) -> None: + config = Config() + config.set_base_url("https://api.dev.example/api/v1") + config.set_sandbox_base_url("https://sandbox.dev.example/api/v1") + config.save_environment("dev") + + env_file = temp_home / ".prime" / "environments" / "dev.json" + env_config = json.loads(env_file.read_text()) + + assert config.sandbox_base_url == "https://sandbox.dev.example" + assert config.view()["sandbox_base_url"] == "https://sandbox.dev.example" + assert env_config["base_url"] == "https://api.dev.example" + assert env_config["sandbox_base_url"] == "https://sandbox.dev.example" + + assert config.load_environment("production") is True + assert config.sandbox_base_url is None + + assert config.load_environment("dev") is True + assert config.base_url == "https://api.dev.example" + assert config.sandbox_base_url == "https://sandbox.dev.example" + + +def test_sandbox_base_url_env_var_overrides_saved_config( + temp_home: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + config = Config() + config.set_sandbox_base_url("https://saved-sandbox.example/api/v1") + + monkeypatch.setenv("PRIME_SANDBOX_BASE_URL", "https://env-sandbox.example/api/v1") + + assert config.sandbox_base_url == "https://env-sandbox.example" + assert config.view()["sandbox_base_url"] == "https://env-sandbox.example" + + +def test_sandbox_ingress_env_var_is_supported( + temp_home: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + config = Config() + config.set_sandbox_base_url("https://saved-sandbox.example/api/v1") + + monkeypatch.setenv("PRIME_SANDBOX_INGRESS_URL", "https://ingress-sandbox.example/api/v1") + + assert config.sandbox_base_url == "https://ingress-sandbox.example" + + +def test_set_sandbox_base_url_command_updates_config_view(temp_home: Path) -> None: + set_result = runner.invoke( + app, + ["config", "set-sandbox-base-url", "https://sandbox.dev.example/api/v1"], + env=TEST_ENV, + ) + + assert set_result.exit_code == 0, set_result.output + + view_result = runner.invoke(app, ["config", "view"], env=TEST_ENV) + + assert view_result.exit_code == 0, view_result.output + assert "Sandbox Base URL" in view_result.output + assert "https://sandbox.dev.example" in view_result.output + + +def test_set_sandbox_base_url_command_updates_active_environment_file(temp_home: Path) -> None: + config = Config() + config.set_base_url("https://api.dev.example") + config.save_environment("dev") + config.load_environment("dev") + + result = runner.invoke( + app, + ["config", "set-sandbox-base-url", "https://sandbox.dev.example/api/v1"], + env=TEST_ENV, + ) + + assert result.exit_code == 0, result.output + + env_file = temp_home / ".prime" / "environments" / "dev.json" + env_config = json.loads(env_file.read_text()) + + assert env_config["sandbox_base_url"] == "https://sandbox.dev.example" + + reloaded = Config() + assert reloaded.load_environment("production") is True + assert reloaded.load_environment("dev") is True + assert reloaded.sandbox_base_url == "https://sandbox.dev.example"