Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions packages/prime-sandboxes/src/prime_sandboxes/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,15 @@ def _check_auth_required(self) -> None:
"No API key configured. Set PRIME_API_KEY environment variable.",
)

def _base_url_for_endpoint(self, endpoint: str) -> str:
normalized_endpoint = endpoint if endpoint.startswith("/") else f"/{endpoint}"
is_sandbox_endpoint = normalized_endpoint == "/sandbox" or normalized_endpoint.startswith(
"/sandbox/"
)
if is_sandbox_endpoint and self.config.sandbox_base_url:
return self.config.sandbox_base_url
return self.base_url

@retry(
retry=retry_if_exception(_is_idempotent_request_retryable_error),
stop=stop_after_attempt(3),
Expand Down Expand Up @@ -169,12 +178,13 @@ def request(
"""Make a request to the API"""
self._check_auth_required()

base_url = self._base_url_for_endpoint(endpoint)
if not endpoint.startswith("/"):
endpoint = f"/api/v1/{endpoint}"
else:
endpoint = f"/api/v1{endpoint}"

url = f"{self.base_url}{endpoint}"
url = f"{base_url}{endpoint}"

try:
method_upper = method.upper()
Expand Down Expand Up @@ -263,6 +273,15 @@ def _check_auth_required(self) -> None:
"No API key configured. Set PRIME_API_KEY environment variable.",
)

def _base_url_for_endpoint(self, endpoint: str) -> str:
normalized_endpoint = endpoint if endpoint.startswith("/") else f"/{endpoint}"
is_sandbox_endpoint = normalized_endpoint == "/sandbox" or normalized_endpoint.startswith(
"/sandbox/"
)
if is_sandbox_endpoint and self.config.sandbox_base_url:
return self.config.sandbox_base_url
return self.base_url

@retry(
retry=retry_if_exception(_is_idempotent_request_retryable_error),
stop=stop_after_attempt(3),
Expand Down Expand Up @@ -330,12 +349,13 @@ async def request(
"""Make an async request to the API"""
self._check_auth_required()

base_url = self._base_url_for_endpoint(endpoint)
if not endpoint.startswith("/"):
endpoint = f"/api/v1/{endpoint}"
else:
endpoint = f"/api/v1{endpoint}"

url = f"{self.base_url}{endpoint}"
url = f"{base_url}{endpoint}"

try:
method_upper = method.upper()
Expand Down
9 changes: 9 additions & 0 deletions packages/prime-sandboxes/src/prime_sandboxes/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,12 @@ def base_url(self) -> str:
if env_val:
return self._strip_api_v1(env_val)
return self._strip_api_v1(self.config.get("base_url", self.DEFAULT_BASE_URL))

@property
def sandbox_base_url(self) -> Optional[str]:
"""Get sandbox API base URL with precedence: env > file > None."""
env_val = os.getenv("PRIME_SANDBOX_BASE_URL") or os.getenv("PRIME_SANDBOX_INGRESS_URL")
value = env_val or self.config.get("sandbox_base_url")
if not value:
return None
return self._strip_api_v1(str(value))
117 changes: 117 additions & 0 deletions packages/prime-sandboxes/tests/test_sandbox_base_url.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import json
from pathlib import Path

import httpx
import pytest

from prime_sandboxes.core.client import APIClient, AsyncAPIClient


class RecordingTransport(httpx.BaseTransport):
def __init__(self) -> None:
self.requests: list[httpx.Request] = []

def handle_request(self, request: httpx.Request) -> httpx.Response:
self.requests.append(request)
return httpx.Response(200, request=request, json={"ok": True})


class AsyncRecordingTransport(httpx.AsyncBaseTransport):
def __init__(self) -> None:
self.requests: list[httpx.Request] = []

async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
self.requests.append(request)
return httpx.Response(200, request=request, json={"ok": True})


@pytest.fixture
def temp_home(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path:
monkeypatch.setenv("HOME", str(tmp_path))
for name in (
"PRIME_API_KEY",
"PRIME_API_BASE_URL",
"PRIME_BASE_URL",
"PRIME_SANDBOX_BASE_URL",
"PRIME_SANDBOX_INGRESS_URL",
):
monkeypatch.delenv(name, raising=False)
return tmp_path


def write_config(home: Path, config: dict[str, str]) -> None:
config_dir = home / ".prime"
config_dir.mkdir()
(config_dir / "config.json").write_text(json.dumps(config))


def test_sync_client_routes_only_sandbox_endpoints_to_sandbox_base_url(
temp_home: Path,
) -> None:
write_config(
temp_home,
{
"api_key": "test-key",
"base_url": "https://api.example",
"sandbox_base_url": "https://sandbox.example",
},
)
transport = RecordingTransport()
client = APIClient()
client.client = httpx.Client(transport=transport)

assert client.request("GET", "/sandbox") == {"ok": True}
assert client.request("GET", "sandbox/sbx-1") == {"ok": True}
assert client.request("GET", "sandboxed") == {"ok": True}
assert client.request("GET", "/template/registry-credentials") == {"ok": True}

assert [str(request.url) for request in transport.requests] == [
"https://sandbox.example/api/v1/sandbox",
"https://sandbox.example/api/v1/sandbox/sbx-1",
"https://api.example/api/v1/sandboxed",
"https://api.example/api/v1/template/registry-credentials",
]


def test_sync_client_sandbox_base_url_env_var_overrides_saved_config(
temp_home: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
write_config(
temp_home,
{
"api_key": "test-key",
"base_url": "https://api.example",
"sandbox_base_url": "https://saved-sandbox.example",
},
)
monkeypatch.setenv("PRIME_SANDBOX_BASE_URL", "https://env-sandbox.example/api/v1")
transport = RecordingTransport()
client = APIClient()
client.client = httpx.Client(transport=transport)

assert client.request("GET", "/sandbox/sbx-1") == {"ok": True}

assert str(transport.requests[0].url) == "https://env-sandbox.example/api/v1/sandbox/sbx-1"


@pytest.mark.asyncio
async def test_async_client_routes_sandbox_endpoints_to_sandbox_base_url(
temp_home: Path,
) -> None:
write_config(
temp_home,
{
"api_key": "test-key",
"base_url": "https://api.example",
"sandbox_base_url": "https://sandbox.example/api/v1",
},
)
transport = AsyncRecordingTransport()
client = AsyncAPIClient()
client.client = httpx.AsyncClient(transport=transport)

assert await client.request("GET", "/sandbox/sbx-1/auth") == {"ok": True}

assert str(transport.requests[0].url) == "https://sandbox.example/api/v1/sandbox/sbx-1/auth"
await client.client.aclose()
28 changes: 28 additions & 0 deletions packages/prime/src/prime_cli/commands/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ def _env_set(*names: str) -> bool:
base_label += " (from env var)"
table.add_row("Base URL", base_label)

# Show sandbox base URL
sandbox_base_label = settings.get("sandbox_base_url") or "Not set"
if _env_set("PRIME_SANDBOX_BASE_URL", "PRIME_SANDBOX_INGRESS_URL"):
sandbox_base_label += " (from env var)"
table.add_row("Sandbox Base URL", sandbox_base_label)

# Show frontend URL
front_label = settings["frontend_url"]
if _env_set("PRIME_FRONTEND_URL"):
Expand Down Expand Up @@ -227,6 +233,27 @@ def set_base_url(
console.print(f"[green]Base URL set to: {url}[/green]")


@app.command()
def set_sandbox_base_url(
url: Optional[str] = typer.Argument(
None,
help="Base URL for the sandbox API. If not provided, you'll be prompted.",
),
) -> None:
"""Set the sandbox API base URL (prompts if not provided)"""
if not url:
config = Config()
url = typer.prompt(
"Enter the base URL for the sandbox API",
default=config.sandbox_base_url or "",
)

config = Config()
config.set_sandbox_base_url(url or None)
config.update_current_environment_file()
console.print(f"[green]Sandbox Base URL set to: {url or 'Not set'}[/green]")


@app.command()
def set_frontend_url(
url: Optional[str] = typer.Argument(
Expand Down Expand Up @@ -382,6 +409,7 @@ def reset(
config.set_api_key("")
config.set_team(None)
config.set_base_url(Config.DEFAULT_BASE_URL)
config.set_sandbox_base_url(None)
config.set_frontend_url(Config.DEFAULT_FRONTEND_URL)
config.set_ssh_key_path(Config.DEFAULT_SSH_KEY_PATH)
config.set_current_environment("production")
Expand Down
26 changes: 26 additions & 0 deletions packages/prime/src/prime_cli/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class ConfigModel(BaseModel):
team_role: str | None = None
user_id: str | None = None
base_url: str = "https://api.primeintellect.ai"
sandbox_base_url: str | None = None
frontend_url: str = "https://app.primeintellect.ai"
inference_url: str = "https://api.pinference.ai/api/v1"
ssh_key_path: str = str(Path.home() / ".ssh" / "id_rsa")
Expand Down Expand Up @@ -57,6 +58,7 @@ def _ensure_config_dir(self) -> None:
team_id=None,
user_id=None,
base_url=self.DEFAULT_BASE_URL,
sandbox_base_url=None,
frontend_url=self.DEFAULT_FRONTEND_URL,
inference_url=self.DEFAULT_INFERENCE_URL,
ssh_key_path=self.DEFAULT_SSH_KEY_PATH,
Expand Down Expand Up @@ -149,6 +151,20 @@ def set_base_url(self, value: str) -> None:
self.config["base_url"] = value
self._save_config(self.config)

@property
def sandbox_base_url(self) -> Optional[str]:
"""Get sandbox API base URL with precedence: env > file > None."""
env_val = os.getenv("PRIME_SANDBOX_BASE_URL") or os.getenv("PRIME_SANDBOX_INGRESS_URL")
value = env_val or self.config.get("sandbox_base_url")
if not value:
return None
return self._strip_api_v1(str(value))

def set_sandbox_base_url(self, value: str | None) -> None:
"""Set sandbox API base URL in config file."""
self.config["sandbox_base_url"] = self._strip_api_v1(value) if value else None
self._save_config(self.config)

@property
def frontend_url(self) -> str:
"""Get frontend URL with precedence: env > file > default."""
Expand Down Expand Up @@ -234,6 +250,7 @@ def view(self) -> dict:
"team_role": self.team_role,
"user_id": self.user_id,
"base_url": self.base_url,
"sandbox_base_url": self.sandbox_base_url,
"frontend_url": self.frontend_url,
"inference_url": self.inference_url,
"ssh_key_path": self.ssh_key_path,
Expand All @@ -255,6 +272,7 @@ def save_environment(self, name: str) -> None:
"team_role": None if self.team_id_from_env else self.team_role,
"user_id": self.user_id,
"base_url": self.base_url,
"sandbox_base_url": self.sandbox_base_url,
"frontend_url": self.frontend_url,
"inference_url": self.inference_url,
}
Expand Down Expand Up @@ -292,12 +310,14 @@ def load_environment(self, name: str, persist: bool = True) -> bool:
# Built-in production environment
if persist:
self.set_base_url(self.DEFAULT_BASE_URL)
self.set_sandbox_base_url(None)
self.set_frontend_url(self.DEFAULT_FRONTEND_URL)
self.set_inference_url(self.DEFAULT_INFERENCE_URL)
self.set_team(None) # Production defaults to personal account
self.set_current_environment("production")
else:
self.config["base_url"] = self.DEFAULT_BASE_URL
self.config["sandbox_base_url"] = None
self.config["frontend_url"] = self.DEFAULT_FRONTEND_URL
self.config["inference_url"] = self.DEFAULT_INFERENCE_URL
self.config["team_id"] = None
Expand Down Expand Up @@ -327,6 +347,7 @@ def load_environment(self, name: str, persist: bool = True) -> bool:
# Set user_id from environment
self.set_user_id(env_config.get("user_id", None))
self.set_base_url(env_config.get("base_url", self.DEFAULT_BASE_URL))
self.set_sandbox_base_url(env_config.get("sandbox_base_url", None))
self.set_frontend_url(env_config.get("frontend_url", self.DEFAULT_FRONTEND_URL))
self.set_inference_url(
env_config.get("inference_url", self.DEFAULT_INFERENCE_URL)
Expand All @@ -343,6 +364,10 @@ def load_environment(self, name: str, persist: bool = True) -> bool:
# Normalize URLs the same way set_* methods do
base_url = env_config.get("base_url", self.DEFAULT_BASE_URL)
self.config["base_url"] = self._strip_api_v1(base_url)
sandbox_base_url = env_config.get("sandbox_base_url", None)
self.config["sandbox_base_url"] = (
self._strip_api_v1(sandbox_base_url) if sandbox_base_url else None
)
frontend_url = env_config.get("frontend_url", self.DEFAULT_FRONTEND_URL)
self.config["frontend_url"] = frontend_url.rstrip("/")
inference_url = env_config.get("inference_url", self.DEFAULT_INFERENCE_URL)
Expand All @@ -369,6 +394,7 @@ def update_current_environment_file(self) -> None:
"team_role": None if self.team_id_from_env else self.team_role,
"user_id": self.user_id,
"base_url": self.base_url,
"sandbox_base_url": self.sandbox_base_url,
"frontend_url": self.frontend_url,
"inference_url": self.inference_url,
}
Expand Down
Loading
Loading