diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5308ad1f..1e424037 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -14,7 +14,7 @@ jobs: if: github.event_name == 'pull_request' runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: fetch-depth: 0 @@ -68,7 +68,7 @@ jobs: lint-format: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Run Ruff uses: chartboost/ruff-action@v1 @@ -106,7 +106,7 @@ jobs: matrix: python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v6 @@ -114,9 +114,10 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install uv - uses: astral-sh/setup-uv@v4 + uses: astral-sh/setup-uv@v8.1.0 with: - version: "latest" + version: "0.11.15" + enable-cache: false - name: Install dependencies working-directory: packages/prime-sandboxes @@ -135,7 +136,7 @@ jobs: matrix: python-version: ["3.11", "3.12", "3.13", "3.14"] steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v6 @@ -143,9 +144,10 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install uv - uses: astral-sh/setup-uv@v4 + uses: astral-sh/setup-uv@v8.1.0 with: - version: "latest" + version: "0.11.15" + enable-cache: false - name: Install dependencies working-directory: packages/prime @@ -161,7 +163,7 @@ jobs: type-check: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Set up Python uses: actions/setup-python@v6 @@ -169,9 +171,10 @@ jobs: python-version: '3.11' - name: Install uv - uses: astral-sh/setup-uv@v4 + uses: astral-sh/setup-uv@v8.1.0 with: - version: "latest" + version: "0.11.15" + enable-cache: false - name: Install dependencies working-directory: packages/prime diff --git a/README.md b/README.md index b8a13c58..68f2a2ec 100644 --- a/README.md +++ b/README.md @@ -152,8 +152,10 @@ prime env push my-environment Prime Lab connects verifiers environments to evaluations, GEPA prompt optimization, and Hosted Training. Start with `prime lab setup` to create a local workspace with starter configs, then use `prime train models` to choose a Hosted Training model with current capacity and pricing. ```bash -# Set up a Lab workspace +# Set up a Lab workspace. +# If authenticated, setup creates an active project named after this folder. prime lab setup +prime project current # List trainable models, capacity, and token pricing prime train models @@ -161,8 +163,11 @@ prime train models # Generate a Hosted Training config prime train init -# Launch the run from the generated config +# Launch the run from the generated config. +# Runs attach to the active project by default. prime train rl.toml +prime train rl.toml --project +prime train rl.toml --no-project # Inspect and manage Hosted Training runs prime train list @@ -171,6 +176,30 @@ prime train metrics prime train checkpoints ``` +Lab projects group related training runs, evaluations, and adapters. By default, +`prime lab setup` creates an active project named after the workspace folder. +Use `prime lab setup --project ` to bind an existing project, +`prime lab setup --project-name "Alphabet Sort Baselines"` to choose the default +project name, or `prime lab setup --no-project` to keep setup local-only. Later, +use `prime project use ` to switch the active workspace project, or +`prime project clear` to stop using one by default. Existing runs and adapters +support project add/remove/clear; evaluations support assign/clear. + +```bash +# Manage projects +prime project list +prime project show +prime project update --description "Baseline alphabet sort runs" + +# Attach existing artifacts +prime project assign run +prime project remove run +prime project assign adapter +prime project remove adapter # clear all adapter project memberships +prime project assign eval +prime project remove eval # clear the evaluation project +``` + ### GPU Resources ```bash @@ -211,6 +240,8 @@ prime eval push # Push specific eval directory (verifiers format) prime eval push outputs/evals/gsm8k--gpt-4/abc123 +prime eval push outputs/evals/gsm8k--gpt-4/abc123 --project +prime eval push outputs/evals/gsm8k--gpt-4/abc123 --no-project # Push a public evaluation (default is private) prime eval push --public diff --git a/packages/prime-evals/README.md b/packages/prime-evals/README.md index 52ee41fb..d677b018 100644 --- a/packages/prime-evals/README.md +++ b/packages/prime-evals/README.md @@ -37,6 +37,7 @@ eval_response = client.create_evaluation( model_name="gpt-4o-mini", dataset="gsm8k", framework="verifiers", + project_id="project-id", metadata={ "version": "1.0", "num_examples": 10, @@ -220,6 +221,29 @@ client.finalize_evaluation(eval_id, metrics=eval_data.get("metrics")) print(f"Successfully pushed evaluation: {eval_id}") ``` +## Project Attachment + +Evaluations can be created inside a Lab project, moved to another project, or +cleared from their project. Evaluation assignment is set/clear; targeted removal +from one project is not supported for evaluations. + +```python +eval_response = client.create_evaluation( + name="gsm8k-project-baseline", + environments=[{"id": "gsm8k"}], + model_name="gpt-4o-mini", + project_id="project-id", +) + +eval_id = eval_response["evaluation_id"] + +# Move the evaluation to another project +client.update_evaluation(eval_id, project_id="another-project-id") + +# Clear the evaluation project +client.update_evaluation(eval_id, clear_project=True) +``` + ## API Reference ### EvalsClient @@ -232,6 +256,7 @@ Main client for interacting with the Prime Evals API. - `push_samples()` - Push evaluation samples - `finalize_evaluation()` - Finalize an evaluation with final metrics - `get_evaluation()` - Get evaluation details by ID +- `update_evaluation()` - Update evaluation details or assign/clear a project - `list_evaluations()` - List evaluations with optional filters - `get_samples()` - Get samples for an evaluation @@ -276,4 +301,3 @@ except EvalsAPIError as e: ## License MIT License - see LICENSE file for details - diff --git a/packages/prime-evals/src/prime_evals/evals.py b/packages/prime-evals/src/prime_evals/evals.py index d4a6e35c..6ffa5fb0 100644 --- a/packages/prime-evals/src/prime_evals/evals.py +++ b/packages/prime-evals/src/prime_evals/evals.py @@ -161,6 +161,7 @@ def create_evaluation( task_type: Optional[str] = None, description: Optional[str] = None, tags: Optional[List[str]] = None, + project_id: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, metrics: Optional[Dict[str, Any]] = None, is_public: Optional[bool] = None, @@ -204,6 +205,7 @@ def create_evaluation( "task_type": task_type, "description": description, "tags": tags or [], + "project_id": project_id, "metadata": metadata, "metrics": metrics, } @@ -367,6 +369,8 @@ def update_evaluation( task_type: Optional[str] = None, description: Optional[str] = None, tags: Optional[List[str]] = None, + project_id: Optional[str] = None, + clear_project: bool = False, metadata: Optional[Dict[str, Any]] = None, metrics: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: @@ -377,11 +381,16 @@ def update_evaluation( "framework": framework, "task_type": task_type, "description": description, - "tags": tags if tags is not None else [], + "tags": tags, + "project_id": project_id, "metadata": metadata, "metrics": metrics, } - payload = {k: v for k, v in payload.items() if v is not None or k in ["tags"]} + payload = { + k: v + for k, v in payload.items() + if v is not None or (clear_project and k == "project_id") + } response = self.client.request("PUT", f"/evaluations/{evaluation_id}", json=payload) return response @@ -519,6 +528,7 @@ async def create_evaluation( task_type: Optional[str] = None, description: Optional[str] = None, tags: Optional[List[str]] = None, + project_id: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, metrics: Optional[Dict[str, Any]] = None, is_public: Optional[bool] = None, @@ -562,6 +572,7 @@ async def create_evaluation( "task_type": task_type, "description": description, "tags": tags or [], + "project_id": project_id, "metadata": metadata, "metrics": metrics, } @@ -719,6 +730,8 @@ async def update_evaluation( task_type: Optional[str] = None, description: Optional[str] = None, tags: Optional[List[str]] = None, + project_id: Optional[str] = None, + clear_project: bool = False, metadata: Optional[Dict[str, Any]] = None, metrics: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: @@ -729,11 +742,16 @@ async def update_evaluation( "framework": framework, "task_type": task_type, "description": description, - "tags": tags if tags is not None else [], + "tags": tags, + "project_id": project_id, "metadata": metadata, "metrics": metrics, } - payload = {k: v for k, v in payload.items() if v is not None or k in ["tags"]} + payload = { + k: v + for k, v in payload.items() + if v is not None or (clear_project and k == "project_id") + } response = await self.client.request("PUT", f"/evaluations/{evaluation_id}", json=payload) return response diff --git a/packages/prime-evals/src/prime_evals/models.py b/packages/prime-evals/src/prime_evals/models.py index 7d278dcf..953387a0 100644 --- a/packages/prime-evals/src/prime_evals/models.py +++ b/packages/prime-evals/src/prime_evals/models.py @@ -32,6 +32,7 @@ class Evaluation(BaseModel): run_id: Optional[str] = Field(None, alias="runId") version_id: Optional[str] = Field(None, alias="versionId") tags: List[str] = Field(default_factory=list) + project_id: Optional[str] = Field(None, alias="projectId") metadata: Optional[Dict[str, Any]] = None metrics: Optional[Dict[str, Any]] = None total_samples: Optional[int] = Field(None, alias="totalSamples") @@ -66,6 +67,7 @@ class CreateEvaluationRequest(BaseModel): task_type: Optional[str] = None description: Optional[str] = None tags: List[str] = Field(default_factory=list) + project_id: Optional[str] = None metadata: Optional[Dict[str, Any]] = None metrics: Optional[Dict[str, Any]] = None diff --git a/packages/prime-evals/tests/test_evals.py b/packages/prime-evals/tests/test_evals.py index 016c24c3..e845056c 100644 --- a/packages/prime-evals/tests/test_evals.py +++ b/packages/prime-evals/tests/test_evals.py @@ -232,6 +232,61 @@ def test_evals_client_context_manager(): pass # Expected to fail without proper initialization +def test_create_evaluation_sends_project_id_payload(): + captured = {} + + class DummyConfig: + team_id = None + + class DummyHTTPClient: + config = DummyConfig() + + def request(self, method, endpoint, json=None, params=None): + captured["method"] = method + captured["endpoint"] = endpoint + captured["json"] = json + captured["params"] = params + return {"evaluation_id": "eval-123"} + + client = EvalsClient.__new__(EvalsClient) + client.client = DummyHTTPClient() + + response = client.create_evaluation( + name="gsm8k", + run_id="run-123", + model_name="gpt-4o-mini", + project_id="project-123", + ) + + assert response == {"evaluation_id": "eval-123"} + assert captured["method"] == "POST" + assert captured["endpoint"] == "/evaluations/" + assert captured["json"]["project_id"] == "project-123" + assert "projectId" not in captured["json"] + + +def test_update_evaluation_clear_project_sends_null_project_id(): + captured = {} + + class DummyHTTPClient: + def request(self, method, endpoint, json=None, params=None): + captured["method"] = method + captured["endpoint"] = endpoint + captured["json"] = json + captured["params"] = params + return {"evaluation_id": "eval-123"} + + client = EvalsClient.__new__(EvalsClient) + client.client = DummyHTTPClient() + + response = client.update_evaluation("eval-123", clear_project=True) + + assert response == {"evaluation_id": "eval-123"} + assert captured["method"] == "PUT" + assert captured["endpoint"] == "/evaluations/eval-123" + assert captured["json"] == {"project_id": None} + + def test_evaluation_model_minimal(): """Test Evaluation model with minimal data""" data = { diff --git a/packages/prime/README.md b/packages/prime/README.md index 470f85f6..9f25dabe 100644 --- a/packages/prime/README.md +++ b/packages/prime/README.md @@ -114,8 +114,10 @@ prime sandbox create python:3.11 Prime Lab connects verifiers environments to evaluations, GEPA prompt optimization, and Hosted Training. Start with `prime lab setup` to create a local workspace with starter configs, then use `prime train models` to choose a Hosted Training model with current capacity and pricing. ```bash -# Set up a Lab workspace +# Set up a Lab workspace. +# If authenticated, setup creates an active project named after this folder. prime lab setup +prime project current # List trainable models, capacity, and token pricing prime train models @@ -123,8 +125,11 @@ prime train models # Generate a Hosted Training config prime train init -# Launch the run from the generated config +# Launch the run from the generated config. +# Runs attach to the active project by default. prime train rl.toml +prime train rl.toml --project +prime train rl.toml --no-project # Inspect and manage Hosted Training runs prime train list @@ -133,6 +138,30 @@ prime train metrics prime train checkpoints ``` +Lab projects group related training runs, evaluations, and adapters. By default, +`prime lab setup` creates an active project named after the workspace folder. +Use `prime lab setup --project ` to bind an existing project, +`prime lab setup --project-name "Alphabet Sort Baselines"` to choose the default +project name, or `prime lab setup --no-project` to keep setup local-only. Later, +use `prime project use ` to switch the active workspace project, or +`prime project clear` to stop using one by default. Existing runs and adapters +support project add/remove/clear; evaluations support assign/clear. + +```bash +# Manage projects +prime project list +prime project show +prime project update --description "Baseline alphabet sort runs" + +# Attach existing artifacts +prime project assign run +prime project remove run +prime project assign adapter +prime project remove adapter # clear all adapter project memberships +prime project assign eval +prime project remove eval # clear the evaluation project +``` + ### Environments Hub Access hundreds of RL environments on our community hub with deep integrations with sandboxes, training, and evaluation stack. diff --git a/packages/prime/src/prime_cli/api/deployments.py b/packages/prime/src/prime_cli/api/deployments.py index dee21359..7bd1e6c6 100644 --- a/packages/prime/src/prime_cli/api/deployments.py +++ b/packages/prime/src/prime_cli/api/deployments.py @@ -15,6 +15,7 @@ class Adapter(BaseModel): display_name: Optional[str] = Field(None, alias="displayName") user_id: str = Field(..., alias="userId") team_id: Optional[str] = Field(None, alias="teamId") + project_id: Optional[str] = Field(None, alias="projectId") rft_run_id: str = Field(..., alias="rftRunId") base_model: str = Field(..., alias="baseModel") step: Optional[int] = Field(None, description="Training step number") @@ -92,6 +93,26 @@ def unload_adapter(self, adapter_id: str) -> Adapter: raise APIError(f"Failed to unload adapter: {e.response.text}") raise APIError(f"Failed to unload adapter: {str(e)}") + def update_adapter_project( + self, + adapter_id: str, + project_id: Optional[str], + *, + operation: str = "set", + ) -> Adapter: + """Update adapter project memberships.""" + try: + response = self.client.request( + "PATCH", + f"/rft/adapters/{adapter_id}/project", + json={"projectId": project_id, "operation": operation}, + ) + return Adapter.model_validate(response.get("adapter")) + except Exception as e: + if hasattr(e, "response") and hasattr(e.response, "text"): + raise APIError(f"Failed to update adapter project: {e.response.text}") + raise APIError(f"Failed to update adapter project: {str(e)}") + def get_deployable_models(self) -> List[str]: """Get list of base models that support LoRA deployment.""" try: diff --git a/packages/prime/src/prime_cli/api/projects.py b/packages/prime/src/prime_cli/api/projects.py new file mode 100644 index 00000000..32fa7af2 --- /dev/null +++ b/packages/prime/src/prime_cli/api/projects.py @@ -0,0 +1,99 @@ +"""Lab Projects API client.""" + +from datetime import datetime +from typing import Optional + +from pydantic import BaseModel, ConfigDict, Field + +from prime_cli.core import APIClient, APIError + + +class Project(BaseModel): + id: str + name: str + slug: str + description: Optional[str] = None + status: str + user_id: str = Field(alias="userId") + team_id: Optional[str] = Field(None, alias="teamId") + created_at: datetime = Field(alias="createdAt") + updated_at: datetime = Field(alias="updatedAt") + archived_at: Optional[datetime] = Field(None, alias="archivedAt") + + model_config = ConfigDict(populate_by_name=True) + + +class ProjectsClient: + def __init__(self, client: APIClient) -> None: + self.client = client + + def list( + self, + team_id: Optional[str] = None, + limit: int = 100, + offset: int = 0, + ) -> tuple[list[Project], int]: + params: dict[str, object] = { + "limit": limit, + "offset": offset, + } + if team_id: + params["teamId"] = team_id + + try: + response = self.client.get("/projects", params=params) + data = response.get("data", []) + total = int(response.get("totalCount", response.get("total_count", len(data)))) + return [Project.model_validate(item) for item in data], total + except Exception as exc: + raise APIError(f"Failed to list projects: {exc}") from exc + + def create( + self, + name: str, + slug: Optional[str] = None, + description: Optional[str] = None, + team_id: Optional[str] = None, + ) -> Project: + payload = { + "name": name, + "slug": slug, + "description": description, + "teamId": team_id, + } + payload = {key: value for key, value in payload.items() if value is not None} + + try: + response = self.client.post("/projects", json=payload) + return Project.model_validate(response["data"]) + except Exception as exc: + raise APIError(f"Failed to create project: {exc}") from exc + + def get(self, project_ref: str, team_id: Optional[str] = None) -> Project: + params = {"teamId": team_id} if team_id else None + try: + response = self.client.get(f"/projects/{project_ref}", params=params) + return Project.model_validate(response["data"]) + except Exception as exc: + raise APIError(f"Failed to get project: {exc}") from exc + + def update( + self, + project_ref: str, + name: Optional[str] = None, + slug: Optional[str] = None, + description: Optional[str] = None, + team_id: Optional[str] = None, + ) -> Project: + payload = { + "name": name, + "slug": slug, + "description": description, + } + payload = {key: value for key, value in payload.items() if value is not None} + params = {"teamId": team_id} if team_id else None + try: + response = self.client.patch(f"/projects/{project_ref}", json=payload, params=params) + return Project.model_validate(response["data"]) + except Exception as exc: + raise APIError(f"Failed to update project: {exc}") from exc diff --git a/packages/prime/src/prime_cli/api/rl.py b/packages/prime/src/prime_cli/api/rl.py index 28991739..1ade3099 100644 --- a/packages/prime/src/prime_cli/api/rl.py +++ b/packages/prime/src/prime_cli/api/rl.py @@ -66,6 +66,7 @@ class RLRun(BaseModel): name: Optional[str] = Field(None, description="Run name") user_id: str = Field(..., alias="userId") team_id: Optional[str] = Field(None, alias="teamId") + project_id: Optional[str] = Field(None, alias="projectId") cluster_id: Optional[str] = Field(None, alias="rftClusterId") status: str = Field(..., description="Run status") @@ -194,6 +195,7 @@ def create_run( wandb_run_name: Optional[str] = None, secrets: Optional[Dict[str, str]] = None, team_id: Optional[str] = None, + project_id: Optional[str] = None, eval_config: Optional[Dict[str, Any]] = None, val_config: Optional[Dict[str, Any]] = None, buffer_config: Optional[Dict[str, Any]] = None, @@ -252,6 +254,9 @@ def create_run( if team_id: payload["team_id"] = team_id + if project_id: + payload["project_id"] = project_id + if max_tokens: payload["max_tokens"] = max_tokens @@ -365,6 +370,35 @@ def restart_run(self, run_id: str) -> RLRun: raise APIError(f"Failed to restart Hosted Training run: {e.response.text}") raise APIError(f"Failed to restart Hosted Training run: {str(e)}") + def update_run_project( + self, + run_id: str, + project_id: Optional[str], + *, + operation: str = "set", + move_adapters: bool = True, + ) -> tuple[RLRun, int]: + """Update Hosted Training run project memberships.""" + try: + response = self.client.request( + "PATCH", + f"/rft/runs/{run_id}/project", + json={ + "projectId": project_id, + "operation": operation, + "moveAdapters": move_adapters, + }, + ) + run = RLRun.model_validate(response.get("run")) + adapters_updated = int( + response.get("adaptersUpdated", response.get("adapters_updated", 0)) + ) + return run, adapters_updated + except Exception as e: + if hasattr(e, "response") and hasattr(e.response, "text"): + raise APIError(f"Failed to update Hosted Training run project: {e.response.text}") + raise APIError(f"Failed to update Hosted Training run project: {str(e)}") + def list_checkpoints( self, run_id: str, status_filter: Optional[str] = None ) -> List[RLCheckpoint]: diff --git a/packages/prime/src/prime_cli/commands/evals.py b/packages/prime/src/prime_cli/commands/evals.py index 71cdea37..1878621a 100644 --- a/packages/prime/src/prime_cli/commands/evals.py +++ b/packages/prime/src/prime_cli/commands/evals.py @@ -31,6 +31,7 @@ clean_logs, get_new_log_lines, ) +from ..utils.projects import resolve_project_id from ..verifiers_bridge import ( DEFAULT_ENV_DIR_PATH, DEFAULT_MODEL, @@ -562,7 +563,9 @@ def _build_hosted_evaluation_payload(config: HostedEvalConfig) -> dict[str, Any] def _create_hosted_evaluations( - config: HostedEvalConfig, environment_ids: Optional[list[str]] = None + config: HostedEvalConfig, + environment_ids: Optional[list[str]] = None, + project_id: Optional[str] = None, ) -> dict[str, Any]: client = APIClient() payload = _build_hosted_evaluation_payload(config) @@ -572,6 +575,8 @@ def _create_hosted_evaluations( if client.config.team_id: payload["team_id"] = client.config.team_id + if project_id is not None: + payload["project_id"] = project_id created = client.post("/hosted-evaluations", json=payload) evaluation_id = created.get("evaluation_id") @@ -1044,6 +1049,7 @@ def _push_single_eval( eval_id: Optional[str], is_public: bool = False, name: Optional[str] = None, + project_id: Optional[str] = None, ) -> str: path = _validate_eval_path(config_path) eval_data = _load_eval_directory(path) @@ -1081,6 +1087,7 @@ def _push_single_eval( metadata=eval_data.get("metadata"), metrics=eval_data.get("metrics"), tags=eval_data.get("tags", []), + project_id=project_id, ) console.print(f"[green]✓ Updated evaluation:[/green] {eval_id}") except Exception as e: @@ -1099,6 +1106,7 @@ def _push_single_eval( metadata=eval_data.get("metadata"), metrics=eval_data.get("metrics"), tags=eval_data.get("tags", []), + project_id=project_id, is_public=is_public, ) @@ -1210,6 +1218,16 @@ def push_eval( "--public", help="Make the pushed evaluation public. Evaluations are private by default.", ), + project: Optional[str] = typer.Option( + None, + "--project", + help="Project ID or slug. Defaults to the active project for this workspace.", + ), + no_project: bool = typer.Option( + False, + "--no-project", + help="Do not attach this evaluation to the active project.", + ), ) -> None: """Push evaluation data to Prime Evals. @@ -1240,10 +1258,20 @@ def push_eval( console.print(" prime eval push outputs/evals/env--model/run-id --eval-id ") raise typer.Exit(1) + project_id = resolve_project_id(project, no_project=no_project) + if config_path is None: current_dir = Path(".") if _has_eval_files(current_dir): - result_eval_id = _push_single_eval(".", env_id, run_id, eval_id, is_public, name) + result_eval_id = _push_single_eval( + ".", + env_id, + run_id, + eval_id, + is_public, + name, + project_id, + ) if output == "json": console.print() output_data_as_json({"evaluation_id": result_eval_id}, console) @@ -1267,7 +1295,13 @@ def push_eval( for eval_dir in eval_dirs: try: result_eval_id = _push_single_eval( - str(eval_dir), env_id, run_id, eval_id, is_public, name + str(eval_dir), + env_id, + run_id, + eval_id, + is_public, + name, + project_id, ) results.append( {"path": str(eval_dir), "eval_id": result_eval_id, "status": "success"} @@ -1291,7 +1325,15 @@ def push_eval( return - result_eval_id = _push_single_eval(config_path, env_id, run_id, eval_id, is_public, name) + result_eval_id = _push_single_eval( + config_path, + env_id, + run_id, + eval_id, + is_public, + name, + project_id, + ) if output == "json": console.print() @@ -1459,6 +1501,16 @@ def run_eval_cmd( "--eval-name", help="Custom name for the hosted evaluation", ), + project: Optional[str] = typer.Option( + None, + "--project", + help="Project ID or slug. Defaults to the active project for this workspace.", + ), + no_project: bool = typer.Option( + False, + "--no-project", + help="Do not attach this evaluation to the active project.", + ), ) -> None: """Run an evaluation with local-first environment resolution.""" passthrough_args = list(ctx.args) @@ -1478,6 +1530,11 @@ def run_eval_cmd( raise typer.Exit(2) env_dir_path: Optional[str] = None + try: + project_id = resolve_project_id(project, no_project=no_project) + except APIError as exc: + console.print(f"[red]Error:[/red] {exc}") + raise typer.Exit(1) from exc poll_interval_was_provided = ( ctx.get_parameter_source("poll_interval") == ParameterSource.COMMANDLINE ) @@ -1797,10 +1854,12 @@ def run_eval_cmd( api_base_url=target.get("api_base_url"), api_key_var=target.get("api_key_var"), ) - result = _create_hosted_evaluations( - hosted_config, - environment_ids=group["environment_ids"], - ) + hosted_kwargs: dict[str, Any] = { + "environment_ids": group["environment_ids"], + } + if project_id is not None: + hosted_kwargs["project_id"] = project_id + result = _create_hosted_evaluations(hosted_config, **hosted_kwargs) all_platform_slugs.extend(group["platform_slugs"]) all_evaluation_ids.extend(result.get("evaluation_ids") or [result["evaluation_id"]]) except APIError as exc: @@ -1840,4 +1899,6 @@ def run_eval_cmd( passthrough_args=local_passthrough_args, skip_upload=skip_upload, env_path=env_path, + project_id=project_id, + use_active_project=False, ) diff --git a/packages/prime/src/prime_cli/commands/projects.py b/packages/prime/src/prime_cli/commands/projects.py new file mode 100644 index 00000000..3fa07ad5 --- /dev/null +++ b/packages/prime/src/prime_cli/commands/projects.py @@ -0,0 +1,742 @@ +"""Lab Project commands.""" + +from typing import Optional + +import typer +from prime_evals import EvalsAPIError, EvalsClient +from rich.table import Table + +from prime_cli.api.deployments import DeploymentsClient +from prime_cli.api.projects import Project, ProjectsClient +from prime_cli.api.rl import RLClient +from prime_cli.core import Config + +from ..client import APIClient, APIError +from ..utils import ( + PlainTyper, + get_console, + json_output_help, + output_data_as_json, + validate_output_format, +) +from ..utils.projects import ( + clear_project_context, + ensure_active_project_scope, + get_active_project_id, + read_project_context, + write_project_context, +) + +console = get_console() + + +def _usage_help(*examples: tuple[str, str], json_help_text: Optional[str] = None) -> str: + lines = ["\b", "Examples:"] + for command, annotation in examples: + lines.append(f" {command}") + lines.append(f" {annotation}") + if json_help_text: + lines.append("") + lines.append("\b") + lines.extend(json_help_text.splitlines()) + return "\n".join(lines) + + +PROJECT_USAGE_HELP = _usage_help( + ( + 'prime project create "Alphabet Sort Baselines"', + "Create a project and make it active for the current workspace.", + ), + ( + "prime project current", + "Show the active project that new Lab runs and evals will use.", + ), + ( + "prime project use ", + "Switch this workspace to an existing project.", + ), + ( + "prime train rl.toml --no-project", + "Launch a run without attaching it to the active workspace project.", + ), +) + +app = PlainTyper( + help="Create and switch between Lab projects", + no_args_is_help=True, + epilog=PROJECT_USAGE_HELP, +) + +PROJECT_JSON_HELP = json_output_help( + ".project = {id, name, slug, description?, status, userId, teamId?, createdAt, updatedAt}", +) + +PROJECT_LIST_JSON_HELP = json_output_help( + ".projects[] = {id, name, slug, description?, status, userId, teamId?, createdAt, updatedAt}", + ".total_count = number", +) + +PROJECT_CREATE_HELP = _usage_help( + ( + 'prime project create "Alphabet Sort Baselines"', + "Create a project and set it as active for this workspace.", + ), + ( + 'prime project create "Alphabet Sort Baselines" --description "Baseline runs"', + "Create with a description shown in project details.", + ), + ( + 'prime project create "Team Project" --team-id --no-use', + "Create under a team without changing the active workspace project.", + ), + json_help_text=PROJECT_JSON_HELP, +) + +PROJECT_LIST_HELP = _usage_help( + ( + "prime project list", + "List active projects for the current personal or team context.", + ), + ( + "prime project list --limit 50 --offset 50", + "Page through active projects.", + ), + ( + "prime project list --output json", + "Print machine-readable project rows.", + ), + json_help_text=PROJECT_LIST_JSON_HELP, +) + +PROJECT_SHOW_HELP = _usage_help( + ( + "prime project show", + "Show the active workspace project.", + ), + ( + "prime project show ", + "Show a specific project by id.", + ), + ( + "prime project show --output json", + "Print project details as JSON.", + ), + json_help_text=PROJECT_JSON_HELP, +) + +PROJECT_USE_HELP = _usage_help( + ( + "prime project use ", + "Set the active project for this workspace.", + ), + ( + "prime switch ", + "Switch to a team before setting one of its projects as active.", + ), + ( + "prime project current", + "Confirm which project this workspace will use by default.", + ), + json_help_text=PROJECT_JSON_HELP, +) + +PROJECT_CURRENT_HELP = _usage_help( + ( + "prime project current", + "Show the active workspace project.", + ), + ( + "prime project current --output json", + "Print the active project, or null when none is set.", + ), + json_help_text=PROJECT_JSON_HELP, +) + +PROJECT_UPDATE_HELP = _usage_help( + ( + 'prime project update --description "Baseline alphabet sort runs"', + "Update the active project's description.", + ), + ( + 'prime project update --name "New Project Name"', + "Rename a specific project.", + ), + ( + "prime project update --clear-description", + "Clear the description field.", + ), + json_help_text=PROJECT_JSON_HELP, +) + +PROJECT_CLEAR_HELP = _usage_help( + ( + "prime project clear", + "Stop attaching new Lab runs and evals to a workspace project by default.", + ), +) + +PROJECT_ASSIGN_HELP = _usage_help( + ( + "prime project assign run ", + "Add a training run to the active project and move its adapters too.", + ), + ( + "prime project assign run --no-move-adapters", + "Add a run to a specific project without changing adapter project membership.", + ), + ( + "prime project assign eval ", + "Set an evaluation's project.", + ), + ( + "prime project assign adapter ", + "Add an adapter to a project.", + ), +) + +PROJECT_REMOVE_HELP = _usage_help( + ( + "prime project remove run ", + "Clear all project memberships from a training run and its adapters.", + ), + ( + "prime project remove run --no-move-adapters", + "Remove one project from a run without changing adapter project membership.", + ), + ( + "prime project remove eval ", + "Clear an evaluation's project.", + ), + ( + "prime project remove adapter ", + "Remove one project from an adapter.", + ), +) + + +def _project_payload(project: Project) -> dict: + return project.model_dump(mode="json", by_alias=True) + + +def _print_project(project: Project, *, active: bool = False) -> None: + table = Table(title="Active Project" if active else "Project") + table.add_column("Field", style="cyan") + table.add_column("Value", style="green") + table.add_row("Name", project.name) + table.add_row("Slug", project.slug) + table.add_row("ID", project.id) + table.add_row("Status", project.status) + table.add_row("Team", project.team_id or "Personal") + table.add_row("Description", project.description or "Not set") + table.add_row("Created", project.created_at.isoformat()) + table.add_row("Updated", project.updated_at.isoformat()) + if project.archived_at: + table.add_row("Archived", project.archived_at.isoformat()) + console.print(table) + + +def _active_project_ref_or_exit(config: Config) -> str: + active_project_id = get_active_project_id(config) + if not active_project_id: + console.print("[yellow]No active project for this workspace.[/yellow]") + raise typer.Exit(1) + return active_project_id + + +def _normalize_artifact_kind(kind: str) -> str: + normalized = kind.strip().lower().replace("_", "-") + if normalized in {"run", "training", "training-run", "rft-run"}: + return "training_run" + if normalized in {"eval", "evaluation"}: + return "evaluation" + if normalized in {"adapter", "deployment", "inference", "lora"}: + return "adapter" + + console.print("[red]Error:[/red] Kind must be one of run, eval, or adapter.") + raise typer.Exit(1) + + +def _assignment_payload( + *, + kind: str, + artifact_id: str, + project_id: Optional[str], + project_slug: Optional[str] = None, + adapters_updated: Optional[int] = None, +) -> dict: + payload = { + "artifact_type": kind, + "artifact_id": artifact_id, + "project_id": project_id, + "project_slug": project_slug, + } + if adapters_updated is not None: + payload["adapters_updated"] = adapters_updated + return payload + + +def _print_assignment_result(payload: dict, *, removed: bool = False) -> None: + table = Table(title="Project Assignment") + table.add_column("Field", style="cyan") + table.add_column("Value", style="green") + table.add_row("Artifact Type", str(payload["artifact_type"])) + table.add_row("Artifact ID", str(payload["artifact_id"])) + table.add_row("Project", str(payload["project_slug"] or payload["project_id"] or "None")) + if payload.get("adapters_updated") is not None: + table.add_row("Adapters Updated", str(payload["adapters_updated"])) + + console.print( + "[green]✓ Project removed[/green]" if removed else "[green]✓ Project assigned[/green]" + ) + console.print(table) + + +@app.command("create", epilog=PROJECT_CREATE_HELP) +def create_project( + name: str = typer.Argument(..., help="Project display name"), + slug: Optional[str] = typer.Option(None, "--slug", help="Stable project slug"), + description: Optional[str] = typer.Option( + None, + "--description", + "-d", + help="Project description", + ), + team_id: Optional[str] = typer.Option( + None, + "--team-id", + help="Team ID. Defaults to the active CLI team, if any.", + ), + use_project: bool = typer.Option( + True, + "--use/--no-use", + help="Set the created project as the active project for this workspace.", + ), + output: str = typer.Option("table", "--output", "-o", help="Output format: table or json"), +) -> None: + """Create a Lab project.""" + validate_output_format(output, console) + config = Config() + resolved_team_id = team_id if team_id is not None else config.team_id + + try: + if use_project: + ensure_active_project_scope( + resolved_team_id, + config, + action="create and set an active project", + guidance="Use --no-use for one-off team project creation.", + ) + + project = ProjectsClient(APIClient()).create( + name=name, + slug=slug, + description=description, + team_id=resolved_team_id, + ) + if use_project: + write_project_context(project, config) + + if output == "json": + output_data_as_json({"project": _project_payload(project)}, console) + return + + console.print("[green]✓ Project created[/green]") + _print_project(project, active=use_project) + except APIError as exc: + console.print(f"[red]Error:[/red] {exc}") + raise typer.Exit(1) + + +@app.command("list", epilog=PROJECT_LIST_HELP) +def list_projects( + team_id: Optional[str] = typer.Option( + None, + "--team-id", + help="Team ID. Defaults to the active CLI team, if any.", + ), + limit: int = typer.Option(100, "--limit", help="Maximum number of projects to list"), + offset: int = typer.Option(0, "--offset", help="Number of projects to skip"), + output: str = typer.Option("table", "--output", "-o", help="Output format: table or json"), +) -> None: + """List Lab projects for the current workspace.""" + validate_output_format(output, console) + config = Config() + resolved_team_id = team_id if team_id is not None else config.team_id + + try: + projects, total = ProjectsClient(APIClient()).list( + team_id=resolved_team_id, + limit=limit, + offset=offset, + ) + if output == "json": + output_data_as_json( + { + "projects": [_project_payload(project) for project in projects], + "total_count": total, + "offset": offset, + "limit": limit, + }, + console, + ) + return + + active_project_id = get_active_project_id(config) + table = Table(title=f"Projects (Total: {total})") + table.add_column("", style="green", no_wrap=True) + table.add_column("Name", style="blue") + table.add_column("Slug", style="green") + table.add_column("ID", style="cyan", no_wrap=True) + table.add_column("Status", style="yellow") + table.add_column("Updated", style="magenta") + for project in projects: + table.add_row( + "*" if project.id == active_project_id else "", + project.name, + project.slug, + project.id, + project.status, + project.updated_at.isoformat(), + ) + console.print(table) + except APIError as exc: + console.print(f"[red]Error:[/red] {exc}") + raise typer.Exit(1) + + +@app.command("show", epilog=PROJECT_SHOW_HELP) +def show_project( + project_ref: Optional[str] = typer.Argument( + None, + help="Project ID or slug. Defaults to the active project for this workspace.", + ), + team_id: Optional[str] = typer.Option( + None, + "--team-id", + help="Team ID. Defaults to the active CLI team, if any.", + ), + output: str = typer.Option("table", "--output", "-o", help="Output format: table or json"), +) -> None: + """Show a Lab project and its configured fields.""" + validate_output_format(output, console) + config = Config() + resolved_team_id = team_id if team_id is not None else config.team_id + resolved_project_ref = project_ref or _active_project_ref_or_exit(config) + + try: + project = ProjectsClient(APIClient()).get( + resolved_project_ref, + team_id=resolved_team_id, + ) + + if output == "json": + output_data_as_json({"project": _project_payload(project)}, console) + return + + _print_project(project, active=project.id == get_active_project_id(config)) + except APIError as exc: + console.print(f"[red]Error:[/red] {exc}") + raise typer.Exit(1) + + +@app.command("use", epilog=PROJECT_USE_HELP) +def use_project( + project_ref: str = typer.Argument(..., help="Project ID or slug"), + team_id: Optional[str] = typer.Option( + None, + "--team-id", + help="Team ID. Defaults to the active CLI team, if any.", + ), + output: str = typer.Option("table", "--output", "-o", help="Output format: table or json"), +) -> None: + """Set the active Lab project for this workspace.""" + validate_output_format(output, console) + config = Config() + resolved_team_id = team_id if team_id is not None else config.team_id + + try: + project = ProjectsClient(APIClient()).get(project_ref, team_id=resolved_team_id) + ensure_active_project_scope( + project.team_id, + config, + action="set an active project", + ) + write_project_context(project, config) + + if output == "json": + output_data_as_json({"project": _project_payload(project)}, console) + return + + console.print("[green]✓ Active project updated[/green]") + _print_project(project, active=True) + except APIError as exc: + console.print(f"[red]Error:[/red] {exc}") + raise typer.Exit(1) + + +@app.command("current", epilog=PROJECT_CURRENT_HELP) +def current_project( + output: str = typer.Option("table", "--output", "-o", help="Output format: table or json"), +) -> None: + """Show the active Lab project for this workspace.""" + validate_output_format(output, console) + config = Config() + active_project_id = get_active_project_id(config) + context = read_project_context() + + if not active_project_id: + if output == "json": + output_data_as_json({"project": None}, console) + return + console.print("[yellow]No active project for this workspace.[/yellow]") + return + + try: + project = ProjectsClient(APIClient()).get(active_project_id, team_id=config.team_id) + if output == "json": + output_data_as_json({"project": _project_payload(project)}, console) + return + _print_project(project, active=True) + except APIError: + if output == "json": + output_data_as_json( + { + "project": None, + "context": context or {"project_id": active_project_id}, + }, + console, + ) + return + console.print(f"[yellow]Active project:[/yellow] {active_project_id}") + if context.get("project_slug"): + console.print(f"[dim]Slug:[/dim] {context['project_slug']}") + console.print("[dim]Could not fetch current project details from the API.[/dim]") + + +@app.command("update", epilog=PROJECT_UPDATE_HELP) +def update_project( + project_ref: Optional[str] = typer.Argument( + None, + help="Project ID or slug. Defaults to the active project for this workspace.", + ), + name: Optional[str] = typer.Option(None, "--name", help="Project display name"), + slug: Optional[str] = typer.Option(None, "--slug", help="Stable project slug"), + description: Optional[str] = typer.Option( + None, + "--description", + "-d", + help="Project description", + ), + clear_description: bool = typer.Option( + False, + "--clear-description", + help="Clear the project description.", + ), + team_id: Optional[str] = typer.Option( + None, + "--team-id", + help="Team ID. Defaults to the active CLI team, if any.", + ), + output: str = typer.Option("table", "--output", "-o", help="Output format: table or json"), +) -> None: + """Update a Lab project's user-editable fields.""" + validate_output_format(output, console) + config = Config() + resolved_team_id = team_id if team_id is not None else config.team_id + resolved_project_ref = project_ref or _active_project_ref_or_exit(config) + + if description is not None and clear_description: + console.print( + "[red]Error:[/red] Use either --description or --clear-description, not both." + ) + raise typer.Exit(1) + + resolved_description = "" if clear_description else description + + if name is None and slug is None and resolved_description is None: + console.print( + "[red]Error:[/red] Provide --name, --slug, --description, or --clear-description." + ) + raise typer.Exit(1) + + try: + project = ProjectsClient(APIClient()).update( + resolved_project_ref, + name=name, + slug=slug, + description=resolved_description, + team_id=resolved_team_id, + ) + + if get_active_project_id(config) == project.id: + write_project_context(project, config) + + if output == "json": + output_data_as_json({"project": _project_payload(project)}, console) + return + + console.print("[green]✓ Project updated[/green]") + _print_project(project, active=project.id == get_active_project_id(config)) + except APIError as exc: + console.print(f"[red]Error:[/red] {exc}") + raise typer.Exit(1) + + +@app.command("clear", epilog=PROJECT_CLEAR_HELP) +def clear_project() -> None: + """Clear the active Lab project for this workspace.""" + if clear_project_context(): + console.print("[green]✓ Active project cleared[/green]") + else: + console.print("[yellow]No active project was set.[/yellow]") + + +@app.command("assign", epilog=PROJECT_ASSIGN_HELP) +def assign_artifact_to_project( + kind: str = typer.Argument(..., help="Artifact kind: run, eval, or adapter"), + artifact_id: str = typer.Argument(..., help="Training run, evaluation, or adapter ID"), + project_ref: Optional[str] = typer.Argument( + None, + help="Project ID or slug. Defaults to the active project for this workspace.", + ), + team_id: Optional[str] = typer.Option( + None, + "--team-id", + help="Team ID. Defaults to the active CLI team, if any.", + ), + move_adapters: bool = typer.Option( + True, + "--move-adapters/--no-move-adapters", + help="For training runs, also add adapters created by the run to the project.", + ), + output: str = typer.Option("table", "--output", "-o", help="Output format: table or json"), +) -> None: + """Add a training run, evaluation, or adapter to a Lab project.""" + validate_output_format(output, console) + config = Config() + api_client = APIClient() + resolved_team_id = team_id if team_id is not None else config.team_id + resolved_project_ref = project_ref or _active_project_ref_or_exit(config) + normalized_kind = _normalize_artifact_kind(kind) + + try: + project = ProjectsClient(api_client).get( + resolved_project_ref, + team_id=resolved_team_id, + ) + + adapters_updated: Optional[int] = None + if normalized_kind == "training_run": + _, adapters_updated = RLClient(api_client).update_run_project( + artifact_id, + project.id, + operation="add", + move_adapters=move_adapters, + ) + elif normalized_kind == "evaluation": + EvalsClient(api_client).update_evaluation( + artifact_id, + project_id=project.id, + ) + else: + DeploymentsClient(api_client).update_adapter_project( + artifact_id, + project.id, + operation="add", + ) + + payload = _assignment_payload( + kind=normalized_kind, + artifact_id=artifact_id, + project_id=project.id, + project_slug=project.slug, + adapters_updated=adapters_updated, + ) + if output == "json": + output_data_as_json(payload, console) + return + + _print_assignment_result(payload) + except (APIError, EvalsAPIError) as exc: + console.print(f"[red]Error:[/red] {exc}") + raise typer.Exit(1) + + +@app.command("remove", epilog=PROJECT_REMOVE_HELP) +def remove_artifact_from_project( + kind: str = typer.Argument(..., help="Artifact kind: run, eval, or adapter"), + artifact_id: str = typer.Argument(..., help="Training run, evaluation, or adapter ID"), + project_ref: Optional[str] = typer.Argument( + None, + help="Project ID or slug to remove. Omit to remove all project memberships.", + ), + team_id: Optional[str] = typer.Option( + None, + "--team-id", + help="Team ID. Defaults to the active CLI team, if any.", + ), + move_adapters: bool = typer.Option( + True, + "--move-adapters/--no-move-adapters", + help="For training runs, also remove adapters created by the run from the project.", + ), + output: str = typer.Option("table", "--output", "-o", help="Output format: table or json"), +) -> None: + """Remove a training run, evaluation, or adapter from a Lab project.""" + validate_output_format(output, console) + config = Config() + api_client = APIClient() + resolved_team_id = team_id if team_id is not None else config.team_id + normalized_kind = _normalize_artifact_kind(kind) + + if normalized_kind == "evaluation" and project_ref: + console.print( + "[red]Error:[/red] Evaluation project removal clears the evaluation's project. " + "Targeted removal from one project is not supported for evaluations." + ) + console.print("[dim]Omit the project argument to clear the evaluation project.[/dim]") + raise typer.Exit(1) + + try: + project: Optional[Project] = None + if project_ref: + project = ProjectsClient(api_client).get( + project_ref, + team_id=resolved_team_id, + ) + + adapters_updated: Optional[int] = None + if normalized_kind == "training_run": + _, adapters_updated = RLClient(api_client).update_run_project( + artifact_id, + project.id if project else None, + operation="remove" if project else "clear", + move_adapters=move_adapters, + ) + elif normalized_kind == "evaluation": + EvalsClient(api_client).update_evaluation( + artifact_id, + clear_project=True, + ) + else: + DeploymentsClient(api_client).update_adapter_project( + artifact_id, + project.id if project else None, + operation="remove" if project else "clear", + ) + + payload = _assignment_payload( + kind=normalized_kind, + artifact_id=artifact_id, + project_id=project.id if project else None, + project_slug=project.slug if project else None, + adapters_updated=adapters_updated, + ) + if output == "json": + output_data_as_json(payload, console) + return + + _print_assignment_result(payload, removed=True) + except (APIError, EvalsAPIError) as exc: + console.print(f"[red]Error:[/red] {exc}") + raise typer.Exit(1) diff --git a/packages/prime/src/prime_cli/commands/rl.py b/packages/prime/src/prime_cli/commands/rl.py index 7bcdd7c2..a4482784 100644 --- a/packages/prime/src/prime_cli/commands/rl.py +++ b/packages/prime/src/prime_cli/commands/rl.py @@ -36,6 +36,7 @@ format_promo_price, strip_ansi, ) +from ..utils.projects import resolve_project_id from ..utils.prompt import confirm_or_skip from .feedback import submit_feedback from .usage import RUN_USAGE_JSON_HELP, run_usage_command @@ -798,6 +799,7 @@ def _format_run_for_display(run: RLRun) -> Dict[str, Any]: "rollouts": str(run.rollouts_per_example), "created_at": created_at, "team_id": run.team_id, + "project_id": run.project_id, } @@ -878,6 +880,16 @@ def create_run( "--skip-action-check", help="Skip action status check and run even if environment action failed.", ), + project: Optional[str] = typer.Option( + None, + "--project", + help="Project ID or slug. Defaults to the active project for this workspace.", + ), + no_project: bool = typer.Option( + False, + "--no-project", + help="Do not attach this run to the active project.", + ), yes: bool = typer.Option(False, "--yes", "-y", help="Skip confirmation prompt"), ) -> None: """Launch a Hosted Training run from a config file. @@ -932,6 +944,12 @@ def warn(msg: str) -> None: api_client = APIClient() rl_client = RLClient(api_client) app_config = Config() + project_id = resolve_project_id( + project, + no_project=no_project, + config=app_config, + client=api_client, + ) # Kick off pricing fetch in the background so it overlaps with summary # rendering and the action-status checks below. Daemon thread so a slow @@ -964,6 +982,8 @@ def _fetch_pricing() -> None: console.print(f" Environments: {', '.join(e.id for e in cfg.env)}") if app_config.team_id: console.print(f" Team: {app_config.team_id}") + if project_id: + console.print(f" Project: {project_id}") # Training console.print("\n[cyan]Training[/cyan]") @@ -1149,6 +1169,7 @@ def _format(list_p: Any, eff_p: Any) -> str: wandb_run_name=cfg.wandb.name, secrets=secrets if secrets else None, team_id=app_config.team_id, + project_id=project_id, eval_config=cfg.eval.to_api_dict(), val_config=cfg.val.to_api_dict(), buffer_config=cfg.buffer.to_api_dict(), diff --git a/packages/prime/src/prime_cli/lab_setup.py b/packages/prime/src/prime_cli/lab_setup.py index c0874d80..9513382f 100644 --- a/packages/prime/src/prime_cli/lab_setup.py +++ b/packages/prime/src/prime_cli/lab_setup.py @@ -6,6 +6,7 @@ import hashlib import json import os +import re import shutil import subprocess import sys @@ -26,6 +27,8 @@ from rich.table import Table from rich.text import Text +from .api.projects import ProjectsClient +from .core import APIClient, APIError, Config from .lab_agents import ( agent_capability, agent_project_skills_dirs, @@ -40,6 +43,11 @@ run_lab_hygiene_preflight, tracked_lab_hygiene_paths, ) +from .utils.projects import ( + ensure_active_project_scope, + get_active_project_id, + write_project_context, +) VERIFIERS_REPO = "primeintellect-ai/verifiers" VERIFIERS_REF = "f43e42c1fabfe2604afc95b9ce62779a8f55d487" @@ -93,6 +101,9 @@ class LabSetupOptions: skip_agents_md: bool = False skip_install: bool = False agents: tuple[str, ...] = () + no_project: bool = False + project_ref: str | None = None + project_name: str | None = None @dataclass(frozen=True) @@ -245,7 +256,27 @@ def parse_lab_setup_args(args: list[str]) -> LabSetupOptions: action="store_true", help="Use setup defaults without prompts.", ) + parser.add_argument( + "--no-project", + action="store_true", + help="Skip creating or selecting a Lab project during setup.", + ) + parser.add_argument( + "--project", + dest="project_ref", + help="Project ID or slug to make active instead of creating a default project.", + ) + parser.add_argument( + "--project-name", + help="Name for the default project created during setup.", + ) namespace = parser.parse_args(args) + if namespace.no_project and namespace.project_ref: + raise ValueError("--project and --no-project cannot be used together.") + if namespace.no_project and namespace.project_name: + raise ValueError("--project-name and --no-project cannot be used together.") + if namespace.project_ref and namespace.project_name: + raise ValueError("--project and --project-name cannot be used together.") return LabSetupOptions( skip_agents_md=bool(namespace.skip_agents_md), skip_install=bool(namespace.skip_install), @@ -253,6 +284,9 @@ def parse_lab_setup_args(args: list[str]) -> LabSetupOptions: namespace.agents, no_interactive=bool(namespace.no_interactive), ), + no_project=bool(namespace.no_project), + project_ref=namespace.project_ref, + project_name=namespace.project_name, ) @@ -378,6 +412,7 @@ def _run_lab_setup_steps( _report_missing_agent_requirements(options.agents, emit) _prepare_agent_native_surfaces(workspace, options.agents, emit) _sync_lab_metadata(workspace, options.agents, setup_source="prime lab setup") + _ensure_setup_project(options, workspace=workspace, emit=emit) if not options.skip_agents_md: _sync_workspace_guidance(workspace, options.agents, emit, force=True) @@ -1390,6 +1425,64 @@ def _ensure_uv_project(workspace: Path, emit: Emit, runner: Runner) -> None: _check_command(["uv", "add", "verifiers"], workspace, emit, runner) +def _default_project_name(workspace: Path) -> str: + parts = [part for part in re.split(r"[^A-Za-z0-9]+", workspace.name) if part] + if not parts: + return "Lab Project" + return " ".join(f"{part[:1].upper()}{part[1:]}" for part in parts) + + +def _ensure_setup_project( + options: LabSetupOptions, + *, + workspace: Path, + emit: Emit, +) -> None: + if options.no_project: + emit("Skipped Lab project setup (--no-project)\n") + return + + config = Config() + if not config.api_key: + emit( + "Skipped Lab project setup because no API key is configured; " + "run prime login, then prime project create later.\n" + ) + return + + explicit_project_requested = options.project_ref is not None or options.project_name is not None + active_project_id = get_active_project_id(config, workspace=workspace) + if active_project_id and not explicit_project_requested: + emit(f"Using active Lab project {active_project_id}\n") + return + + projects_client = ProjectsClient(APIClient()) + try: + if options.project_ref: + project = projects_client.get(options.project_ref, team_id=config.team_id) + ensure_active_project_scope( + project.team_id, + config, + action="set an active project", + ) + write_project_context(project, config, workspace=workspace) + emit(f"Using Lab project {project.name} ({project.slug})\n") + return + + project_name = options.project_name or _default_project_name(workspace) + project = projects_client.create(name=project_name, team_id=config.team_id) + ensure_active_project_scope( + project.team_id, + config, + action="create and set an active project", + ) + write_project_context(project, config, workspace=workspace) + emit(f"Created Lab project {project.name} ({project.slug})\n") + except APIError as exc: + emit(f"Skipped Lab project setup because the Projects API request failed: {exc}\n") + emit("Run prime project create later to attach new Lab runs and evals by default.\n") + + def _post_setup_call_to_action(options: LabSetupOptions) -> Panel: primary_agent = options.agents[0] if options.agents else "your coding agent" prompt_heading = f"ask {primary_agent}" diff --git a/packages/prime/src/prime_cli/main.py b/packages/prime/src/prime_cli/main.py index 39379578..bea27300 100644 --- a/packages/prime/src/prime_cli/main.py +++ b/packages/prime/src/prime_cli/main.py @@ -20,6 +20,7 @@ from .commands.login import app as login_app from .commands.logout import app as logout_app from .commands.pods import app as pods_app +from .commands.projects import app as projects_app from .commands.registry import app as registry_app from .commands.rl import app as train_app from .commands.sandbox import app as sandbox_app @@ -47,6 +48,7 @@ app.command("fork", rich_help_panel="Lab", epilog=FORK_JSON_HELP)(fork_command) app.add_typer(evals_app, name="eval", rich_help_panel="Lab") app.add_typer(gepa_app, name="gepa", rich_help_panel="Lab") +app.add_typer(projects_app, name="project", rich_help_panel="Lab") app.add_typer(train_app, name="train", rich_help_panel="Lab") app.add_typer( train_app, diff --git a/packages/prime/src/prime_cli/utils/eval_push.py b/packages/prime/src/prime_cli/utils/eval_push.py index 5af038cc..78235368 100644 --- a/packages/prime/src/prime_cli/utils/eval_push.py +++ b/packages/prime/src/prime_cli/utils/eval_push.py @@ -10,6 +10,7 @@ from .display import get_eval_viewer_url from .env_metadata import find_environment_metadata from .plain import get_console +from .projects import resolve_project_id console = get_console() @@ -57,6 +58,8 @@ def push_eval_results_to_hub( job_id: str, env_path: Optional[Path] = None, upstream_slug: Optional[str] = None, + project_id: Optional[str] = None, + use_active_project: bool = True, ) -> None: """ Push evaluation results to Prime Evals Hub after `prime eval run` completes. @@ -191,6 +194,8 @@ def push_eval_results_to_hub( eval_name = f"{env_name}--{model}--{datetime.now().strftime('%Y%m%d_%H%M%S')}" evals_client = EvalsClient(api_client) + if project_id is None and use_active_project: + project_id = resolve_project_id(None, client=api_client) create_response = evals_client.create_evaluation( name=eval_name, @@ -201,6 +206,7 @@ def push_eval_results_to_hub( task_type=metadata.get("task_type"), metadata=eval_metadata, metrics=metrics, + project_id=project_id, is_public=False, # Private by default - only visible to the user who created it ) diff --git a/packages/prime/src/prime_cli/utils/projects.py b/packages/prime/src/prime_cli/utils/projects.py new file mode 100644 index 00000000..96bb0fe7 --- /dev/null +++ b/packages/prime/src/prime_cli/utils/projects.py @@ -0,0 +1,161 @@ +"""Local active-project context for Lab workflows.""" + +import json +import os +from pathlib import Path +from typing import Optional + +from prime_cli.api.projects import Project, ProjectsClient +from prime_cli.core import APIClient, APIError, Config + +PROJECT_CONTEXT_ENV = "PRIME_PROJECT_ID" + + +def project_context_path(workspace: Optional[Path] = None) -> Path: + root = (workspace or Path.cwd()).resolve() + return root / ".prime" / "lab" / "context.json" + + +def _find_lab_workspace_root(workspace: Optional[Path] = None) -> Optional[Path]: + root = (workspace or Path.cwd()).resolve() + for candidate in (root, *root.parents): + if (candidate / ".prime" / "lab.json").exists(): + return candidate + return None + + +def _find_project_context_path(workspace: Optional[Path] = None) -> Optional[Path]: + root = (workspace or Path.cwd()).resolve() + lab_root = _find_lab_workspace_root(root) + if lab_root: + path = project_context_path(lab_root) + return path if path.exists() else None + + path = project_context_path(root) + return path if path.exists() else None + + +def _write_project_context_path(workspace: Optional[Path] = None) -> Path: + existing_path = _find_project_context_path(workspace) + if existing_path: + return existing_path + + lab_root = _find_lab_workspace_root(workspace) + if lab_root: + return project_context_path(lab_root) + + root = (workspace or Path.cwd()).resolve() + return project_context_path(root) + + +def read_project_context(workspace: Optional[Path] = None) -> dict: + path = _find_project_context_path(workspace) + if path is None: + return {} + if not path.exists(): + return {} + try: + data = json.loads(path.read_text(encoding="utf-8")) + except (OSError, json.JSONDecodeError): + return {} + return data if isinstance(data, dict) else {} + + +def scope_label(team_id: Optional[str]) -> str: + return f"team {team_id}" if team_id else "personal account" + + +def ensure_active_project_scope( + project_team_id: Optional[str], + config: Config, + *, + action: str, + guidance: Optional[str] = None, +) -> None: + if project_team_id == config.team_id: + return + + message = ( + f"Cannot {action} for {scope_label(project_team_id)} while the CLI is " + f"using {scope_label(config.team_id)}. Switch account context first with " + "'prime switch '." + ) + if guidance: + message = f"{message} {guidance}" + + raise APIError(message) + + +def write_project_context( + project: Project, + config: Optional[Config] = None, + workspace: Optional[Path] = None, +) -> None: + config = config or Config() + path = _write_project_context_path(workspace) + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text( + json.dumps( + { + "project_id": project.id, + "project_slug": project.slug, + "project_name": project.name, + "team_id": project.team_id, + "base_url": config.base_url, + }, + indent=2, + ) + + "\n", + encoding="utf-8", + ) + + +def clear_project_context(workspace: Optional[Path] = None) -> bool: + path = _find_project_context_path(workspace) + if path is None: + return False + path.unlink() + return True + + +def get_active_project_id( + config: Optional[Config] = None, + workspace: Optional[Path] = None, +) -> Optional[str]: + env_project_id = os.getenv(PROJECT_CONTEXT_ENV) + if env_project_id and env_project_id.strip(): + return env_project_id.strip() + + config = config or Config() + context = read_project_context(workspace) + if not context: + return None + + if context.get("base_url") and context.get("base_url") != config.base_url: + return None + if context.get("team_id") != config.team_id: + return None + + project_id = context.get("project_id") + return str(project_id) if project_id else None + + +def resolve_project_id( + project_ref: Optional[str], + *, + no_project: bool = False, + config: Optional[Config] = None, + client: Optional[APIClient] = None, +) -> Optional[str]: + if project_ref and no_project: + raise APIError("Cannot use --project and --no-project together.") + + if no_project: + return None + + config = config or Config() + if project_ref: + projects_client = ProjectsClient(client or APIClient()) + return projects_client.get(project_ref, team_id=config.team_id).id + + return get_active_project_id(config) diff --git a/packages/prime/src/prime_cli/verifiers_bridge.py b/packages/prime/src/prime_cli/verifiers_bridge.py index fe01b9aa..9edc1a1d 100644 --- a/packages/prime/src/prime_cli/verifiers_bridge.py +++ b/packages/prime/src/prime_cli/verifiers_bridge.py @@ -141,6 +141,8 @@ def _append_eval_options(help_text: str) -> str: "Allow tunnel creation and management for hosted evaluations.", " --custom-secrets JSON Custom sandbox secrets for hosted evaluations.", " --eval-name TEXT Custom name for the hosted evaluation.", + " --project TEXT Project ID or slug. Defaults to the active project.", + " --no-project Do not attach this evaluation to the active project.", ] lines = help_text.rstrip("\n").splitlines() for extra_line in extra_lines: @@ -947,6 +949,8 @@ def run_eval_passthrough( *, skip_upload: bool, env_path: Optional[str], + project_id: Optional[str] = None, + use_active_project: bool = True, ) -> None: plugin = load_verifiers_prime_plugin(console=console) config = Config() @@ -1054,6 +1058,8 @@ def run_eval_passthrough( job_id=job_id, env_path=Path(env_path) if env_path else None, upstream_slug=upstream_slug, + project_id=project_id, + use_active_project=use_active_project, ) except Exception as exc: console.print(f"[red]Failed to push results to hub:[/red] {exc}") diff --git a/packages/prime/tests/test_deployments.py b/packages/prime/tests/test_deployments.py index 759e5de3..69883860 100644 --- a/packages/prime/tests/test_deployments.py +++ b/packages/prime/tests/test_deployments.py @@ -1,6 +1,7 @@ from types import SimpleNamespace from typing import Any +from prime_cli.api.deployments import DeploymentsClient from prime_cli.main import app from prime_cli.utils import strip_ansi from typer.testing import CliRunner @@ -57,3 +58,48 @@ def deploy_adapter(self, model_id: str) -> Any: assert "export PRIME_API_KEY=" in output assert "PRIME_API_KEY" in output assert "curl -X POST" in output + + +def test_update_adapter_project_sends_backend_payload_shape() -> None: + class FakeAPIClient: + def __init__(self) -> None: + self.requests: list[tuple[str, str, dict[str, Any] | None]] = [] + + def request( + self, + method: str, + endpoint: str, + json: dict[str, Any] | None = None, + ) -> dict[str, Any]: + self.requests.append((method, endpoint, json)) + return { + "adapter": { + "id": "adapter-123", + "userId": "user-1", + "projectId": json.get("projectId") if json else None, + "rftRunId": "run-123", + "baseModel": "Qwen/Qwen3.5-0.8B", + "status": "READY", + "deploymentStatus": "NOT_DEPLOYED", + "createdAt": "2026-05-17T00:00:00Z", + "updatedAt": "2026-05-17T00:00:00Z", + } + } + + api_client = FakeAPIClient() + client = DeploymentsClient(api_client) # type: ignore[arg-type] + + adapter = client.update_adapter_project( + "adapter-123", + "project-123", + operation="add", + ) + + assert api_client.requests == [ + ( + "PATCH", + "/rft/adapters/adapter-123/project", + {"projectId": "project-123", "operation": "add"}, + ) + ] + assert adapter.project_id == "project-123" diff --git a/packages/prime/tests/test_eval_help.py b/packages/prime/tests/test_eval_help.py index 3f84969d..50b3c923 100644 --- a/packages/prime/tests/test_eval_help.py +++ b/packages/prime/tests/test_eval_help.py @@ -46,6 +46,8 @@ def test_append_eval_options_mentions_tunnel_access(): help_text = _append_eval_options("Usage: prime eval run [-h] environment\n") assert "--allow-tunnel-access" in help_text + assert "--project TEXT" in help_text + assert "--no-project" in help_text def test_eval_view_uses_prime_viewer(monkeypatch): diff --git a/packages/prime/tests/test_eval_push.py b/packages/prime/tests/test_eval_push.py index 6a9723f9..c0082a13 100644 --- a/packages/prime/tests/test_eval_push.py +++ b/packages/prime/tests/test_eval_push.py @@ -9,12 +9,59 @@ _validate_eval_path, ) from prime_cli.main import app +from prime_cli.utils.eval_push import push_eval_results_to_hub from typer.testing import CliRunner from typing_extensions import cast runner = CliRunner() +def test_push_eval_results_no_project_does_not_resolve_active_project(monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + output_dir = tmp_path / "outputs" / "evals" / "simpleqa--openai--gpt-4" / "20260520_120000" + output_dir.mkdir(parents=True) + (output_dir / "metadata.json").write_text(json.dumps({"task_type": "qa"})) + (output_dir / "results.jsonl").write_text("") + + captured = {} + + class DummyAPIClient: + def get(self, _endpoint): + return {"data": {"id": "env-123"}} + + class DummyEvalsClient: + def __init__(self, _api_client): + pass + + def create_evaluation(self, **kwargs): + captured.update(kwargs) + return {"evaluation_id": "eval-123"} + + def finalize_evaluation(self, evaluation_id, metrics=None): + captured["finalized_evaluation_id"] = evaluation_id + captured["finalized_metrics"] = metrics + + monkeypatch.setattr("prime_cli.utils.eval_push.APIClient", DummyAPIClient) + monkeypatch.setattr("prime_cli.utils.eval_push.EvalsClient", DummyEvalsClient) + monkeypatch.setattr( + "prime_cli.utils.eval_push.resolve_project_id", + lambda *_args, **_kwargs: (_ for _ in ()).throw( + AssertionError("should not resolve active project") + ), + ) + + push_eval_results_to_hub( + env_name="simpleqa", + model="openai/gpt-4", + job_id="job-123", + upstream_slug="owner/simpleqa", + use_active_project=False, + ) + + assert captured["project_id"] is None + assert captured["finalized_evaluation_id"] == "eval-123" + + class TestHasEvalFiles: """Tests for _has_eval_files function""" @@ -150,7 +197,9 @@ def test_push_eval_forwards_name_override(monkeypatch, tmp_path): captured = {} - def fake_push_single_eval(config_path, env_slug, run_id, eval_id, is_public, name): + def fake_push_single_eval( + config_path, env_slug, run_id, eval_id, is_public, name, project_id=None + ): captured.update( { "config_path": config_path, @@ -159,6 +208,7 @@ def fake_push_single_eval(config_path, env_slug, run_id, eval_id, is_public, nam "eval_id": eval_id, "is_public": is_public, "name": name, + "project_id": project_id, } ) return "eval-123" @@ -182,6 +232,56 @@ def fake_push_single_eval(config_path, env_slug, run_id, eval_id, is_public, nam "eval_id": "eval-123", "is_public": False, "name": "custom eval", + "project_id": None, + } + + +def test_push_eval_forwards_resolved_project_id(monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + monkeypatch.setattr("prime_cli.main.check_for_update", lambda: (False, None)) + monkeypatch.setattr( + "prime_cli.commands.evals.resolve_project_id", + lambda project, no_project=False: "project-123", + ) + + captured = {} + + def fake_push_single_eval( + config_path, env_slug, run_id, eval_id, is_public, name, project_id=None + ): + captured.update( + { + "config_path": config_path, + "env_slug": env_slug, + "run_id": run_id, + "eval_id": eval_id, + "is_public": is_public, + "name": name, + "project_id": project_id, + } + ) + return "eval-123" + + monkeypatch.setattr("prime_cli.commands.evals._push_single_eval", fake_push_single_eval) + + (tmp_path / "metadata.json").write_text(json.dumps({"env": "gsm8k", "model": "gpt-4"})) + (tmp_path / "results.jsonl").write_text("") + + result = runner.invoke( + app, + ["eval", "push", ".", "--project", "project-slug"], + env={"PRIME_DISABLE_VERSION_CHECK": "1"}, + ) + + assert result.exit_code == 0, result.output + assert captured == { + "config_path": ".", + "env_slug": None, + "run_id": None, + "eval_id": None, + "is_public": False, + "name": None, + "project_id": "project-123", } diff --git a/packages/prime/tests/test_hosted_eval.py b/packages/prime/tests/test_hosted_eval.py index fa329a8d..11972002 100644 --- a/packages/prime/tests/test_hosted_eval.py +++ b/packages/prime/tests/test_hosted_eval.py @@ -22,6 +22,13 @@ runner = CliRunner() +@pytest.fixture(autouse=True) +def isolated_project_context(monkeypatch, tmp_path): + monkeypatch.chdir(tmp_path) + monkeypatch.setenv("HOME", str(tmp_path)) + monkeypatch.delenv("PRIME_PROJECT_ID", raising=False) + + class TestLogCleaning: def test_strip_ansi_basic(self): assert strip_ansi("\x1b[31mRed text\x1b[0m") == "Red text" @@ -355,6 +362,71 @@ def post(self, endpoint, json=None): assert captured["json"]["team_id"] == "team-123" +def test_create_hosted_evaluation_adds_project_id_to_payload(monkeypatch): + captured = {} + + class DummyConfig: + team_id = None + + class DummyAPIClient: + def __init__(self): + self.config = DummyConfig() + + def post(self, endpoint, json=None): + captured["endpoint"] = endpoint + captured["json"] = json + return {"evaluation_id": "eval-123"} + + monkeypatch.setattr("prime_cli.commands.evals.APIClient", DummyAPIClient) + + result = _create_hosted_evaluations( + HostedEvalConfig( + environment_id="env-123", + inference_model="openai/gpt-4.1-mini", + num_examples=5, + rollouts_per_example=3, + ), + environment_ids=["env-123"], + project_id="project-123", + ) + + assert result["evaluation_id"] == "eval-123" + assert captured["endpoint"] == "/hosted-evaluations" + assert captured["json"]["project_id"] == "project-123" + assert "projectId" not in captured["json"] + + +def test_create_hosted_evaluation_preserves_empty_project_id_payload(monkeypatch): + captured = {} + + class DummyConfig: + team_id = None + + class DummyAPIClient: + def __init__(self): + self.config = DummyConfig() + + def post(self, endpoint, json=None): + captured["endpoint"] = endpoint + captured["json"] = json + return {"evaluation_id": "eval-123"} + + monkeypatch.setattr("prime_cli.commands.evals.APIClient", DummyAPIClient) + + _create_hosted_evaluations( + HostedEvalConfig( + environment_id="env-123", + inference_model="openai/gpt-4.1-mini", + num_examples=5, + rollouts_per_example=3, + ), + project_id="", + ) + + assert captured["endpoint"] == "/hosted-evaluations" + assert captured["json"]["project_id"] == "" + + def test_create_hosted_evaluation_includes_sampling_args_in_payload(monkeypatch): captured = {} @@ -1303,11 +1375,20 @@ def test_eval_run_local_toml_passthrough(monkeypatch, tmp_path): """.strip() ) - def fake_run_eval_passthrough(environment, passthrough_args, skip_upload, env_path): + def fake_run_eval_passthrough( + environment, + passthrough_args, + skip_upload, + env_path, + project_id=None, + use_active_project=True, + ): captured["environment"] = environment captured["passthrough_args"] = passthrough_args captured["skip_upload"] = skip_upload captured["env_path"] = env_path + captured["project_id"] = project_id + captured["use_active_project"] = use_active_project monkeypatch.setattr("prime_cli.commands.evals.run_eval_passthrough", fake_run_eval_passthrough) @@ -1323,17 +1404,28 @@ def fake_run_eval_passthrough(environment, passthrough_args, skip_upload, env_pa "passthrough_args": [], "skip_upload": True, "env_path": None, + "project_id": None, + "use_active_project": False, } def test_eval_run_local_sampling_args_passthrough(monkeypatch): captured = {} - def fake_run_eval_passthrough(environment, passthrough_args, skip_upload, env_path): + def fake_run_eval_passthrough( + environment, + passthrough_args, + skip_upload, + env_path, + project_id=None, + use_active_project=True, + ): captured["environment"] = environment captured["passthrough_args"] = passthrough_args captured["skip_upload"] = skip_upload captured["env_path"] = env_path + captured["project_id"] = project_id + captured["use_active_project"] = use_active_project monkeypatch.setattr("prime_cli.commands.evals.run_eval_passthrough", fake_run_eval_passthrough) @@ -1349,6 +1441,89 @@ def fake_run_eval_passthrough(environment, passthrough_args, skip_upload, env_pa "passthrough_args": ["--sampling-args", '{"temperature":0.2}'], "skip_upload": False, "env_path": None, + "project_id": None, + "use_active_project": False, + } + + +def test_eval_run_local_no_project_disables_active_project(monkeypatch): + captured = {} + + def fake_run_eval_passthrough( + environment, + passthrough_args, + skip_upload, + env_path, + project_id=None, + use_active_project=True, + ): + captured["environment"] = environment + captured["passthrough_args"] = passthrough_args + captured["skip_upload"] = skip_upload + captured["env_path"] = env_path + captured["project_id"] = project_id + captured["use_active_project"] = use_active_project + + monkeypatch.setattr("prime_cli.commands.evals.run_eval_passthrough", fake_run_eval_passthrough) + + result = runner.invoke( + app, + ["eval", "run", "gsm8k", "--no-project"], + env={"PRIME_DISABLE_VERSION_CHECK": "1"}, + ) + + assert result.exit_code == 0, result.output + assert captured == { + "environment": "gsm8k", + "passthrough_args": [], + "skip_upload": False, + "env_path": None, + "project_id": None, + "use_active_project": False, + } + + +def test_eval_run_local_passes_resolved_project_without_active_relookup(monkeypatch): + captured = {} + resolve_calls = [] + + def fake_resolve_project_id(project, *, no_project=False): + resolve_calls.append((project, no_project)) + return "project-123" + + def fake_run_eval_passthrough( + environment, + passthrough_args, + skip_upload, + env_path, + project_id=None, + use_active_project=True, + ): + captured["environment"] = environment + captured["passthrough_args"] = passthrough_args + captured["skip_upload"] = skip_upload + captured["env_path"] = env_path + captured["project_id"] = project_id + captured["use_active_project"] = use_active_project + + monkeypatch.setattr("prime_cli.commands.evals.resolve_project_id", fake_resolve_project_id) + monkeypatch.setattr("prime_cli.commands.evals.run_eval_passthrough", fake_run_eval_passthrough) + + result = runner.invoke( + app, + ["eval", "run", "gsm8k"], + env={"PRIME_DISABLE_VERSION_CHECK": "1"}, + ) + + assert result.exit_code == 0, result.output + assert resolve_calls == [(None, False)] + assert captured == { + "environment": "gsm8k", + "passthrough_args": [], + "skip_upload": False, + "env_path": None, + "project_id": "project-123", + "use_active_project": False, } @@ -1409,6 +1584,41 @@ def fake_create_hosted_evaluations(config, environment_ids=None): assert captured["num_examples"] == -1 +def test_eval_run_hosted_forwards_resolved_project_id(monkeypatch): + captured = {} + + monkeypatch.setattr( + "prime_cli.commands.evals.resolve_project_id", + lambda project, no_project=False: "project-123", + ) + monkeypatch.setattr( + "prime_cli.commands.evals._resolve_hosted_environment", + lambda environment, env_dir_path=None, env_path=None: ("primeintellect/gsm8k", "env-123"), + ) + + def fake_create_hosted_evaluations(config, environment_ids=None, project_id=None): + captured["environment_ids"] = environment_ids + captured["project_id"] = project_id + return {"evaluation_id": "eval-123"} + + monkeypatch.setattr( + "prime_cli.commands.evals._create_hosted_evaluations", + fake_create_hosted_evaluations, + ) + + result = runner.invoke( + app, + ["eval", "run", "gsm8k", "--hosted", "--project", "project-slug"], + env={"PRIME_DISABLE_VERSION_CHECK": "1"}, + ) + + assert result.exit_code == 0, result.output + assert captured == { + "environment_ids": ["env-123"], + "project_id": "project-123", + } + + def test_eval_run_rejects_hosted_only_flags_without_hosted(): result = runner.invoke( app, diff --git a/packages/prime/tests/test_lab_setup.py b/packages/prime/tests/test_lab_setup.py index 561e72e5..f7f8078d 100644 --- a/packages/prime/tests/test_lab_setup.py +++ b/packages/prime/tests/test_lab_setup.py @@ -9,6 +9,7 @@ import pytest from prime_cli import lab_setup +from prime_cli.api.projects import Project from prime_cli.commands.lab import app as lab_cli_app from prime_cli.lab_agents import AgentCapability, known_agent_names from prime_cli.lab_setup import ( @@ -35,6 +36,10 @@ def _git_init(path: Path) -> None: @pytest.fixture(autouse=True) def fake_lab_asset_downloads(monkeypatch: Any) -> list[str]: + monkeypatch.delenv("PRIME_API_KEY", raising=False) + monkeypatch.delenv("PRIME_PROJECT_ID", raising=False) + monkeypatch.delenv("PRIME_TEAM_ID", raising=False) + urls: list[str] = [] skill_names = ( "create-environments", @@ -148,6 +153,27 @@ def _git(cwd: Path, *args: str) -> subprocess.CompletedProcess[str]: ) +def _project( + *, + id: str = "project-123", + name: str = "Alphabet Sort", + slug: str = "alphabet-sort", + team_id: str | None = None, +) -> Project: + return Project.model_validate( + { + "id": id, + "name": name, + "slug": slug, + "status": "ACTIVE", + "userId": "user-123", + "teamId": team_id, + "createdAt": "2026-05-20T12:00:00Z", + "updatedAt": "2026-05-20T12:00:00Z", + } + ) + + def test_lab_setup_parses_selected_agents_and_all() -> None: selected = parse_lab_setup_args(["--agent", "factory-droid,amp-code,claude-code,letta-code"]) all_agents = parse_lab_sync_args(["--agents", "all"]) @@ -162,6 +188,31 @@ def test_lab_setup_no_interactive_uses_codex_default() -> None: assert options.agents == ("codex",) +def test_lab_setup_parses_project_options() -> None: + existing = parse_lab_setup_args(["--agent", "codex", "--project", "project-123"]) + named = parse_lab_setup_args(["--agent", "codex", "--project-name", "Alphabet Sort Baselines"]) + skipped = parse_lab_setup_args(["--agent", "codex", "--no-project"]) + + assert existing.project_ref == "project-123" + assert existing.project_name is None + assert named.project_name == "Alphabet Sort Baselines" + assert named.project_ref is None + assert skipped.no_project is True + + +@pytest.mark.parametrize( + "args", + [ + ["--agent", "codex", "--project", "project-123", "--no-project"], + ["--agent", "codex", "--project-name", "Alphabet Sort", "--no-project"], + ["--agent", "codex", "--project", "project-123", "--project-name", "Alphabet Sort"], + ], +) +def test_lab_setup_rejects_conflicting_project_options(args: list[str]) -> None: + with pytest.raises(ValueError): + parse_lab_setup_args(args) + + def test_lab_setup_help_lists_supported_agents(capsys: Any) -> None: with pytest.raises(SystemExit) as exc_info: parse_lab_setup_args(["--help"]) @@ -292,6 +343,7 @@ def test_lab_setup_service_downloads_upstream_assets_without_agent_installs( assert "/CLAUDE.md" in gitignore.splitlines() assert "/CLAUDE.local.md" in gitignore.splitlines() assert "/.prime/" in gitignore.splitlines() + assert "/.agents/skills/" in gitignore.splitlines() assert (tmp_path / ".pi" / "extensions" / "prime-lab" / "index.ts").is_file() output = _render_emitted(emitted) assert "pi-acp" not in output @@ -324,6 +376,319 @@ def test_lab_setup_hygiene_preflight_nudges_tracked_guidance( assert _git(tmp_path, "ls-files", "--", "AGENTS.md").stdout == "" +def test_lab_setup_creates_default_project_from_workspace_name( + tmp_path: Path, + monkeypatch: Any, +) -> None: + workspace = tmp_path / "alphabet-sort" + monkeypatch.setenv("HOME", str(tmp_path / "home")) + monkeypatch.setenv("PRIME_API_KEY", "test-key") + monkeypatch.setattr(AGENT_WHICH, lambda _command: "/bin/tool") + monkeypatch.setattr("prime_cli.lab_setup.APIClient", lambda: object()) + captured: dict[str, Any] = {} + + class DummyProjectsClient: + def __init__(self, _api_client: object) -> None: + pass + + def create( + self, + name: str, + slug: str | None = None, + description: str | None = None, + team_id: str | None = None, + ) -> Project: + captured.update( + { + "name": name, + "slug": slug, + "description": description, + "team_id": team_id, + } + ) + return _project(name=name, slug="alphabet-sort", team_id=team_id) + + monkeypatch.setattr("prime_cli.lab_setup.ProjectsClient", DummyProjectsClient) + + emitted: list[Any] = [] + result = run_lab_setup_service( + LabSetupOptions(skip_install=True, skip_agents_md=True, agents=("codex",)), + workspace=workspace, + emit=emitted.append, + ) + + context = json.loads((workspace / ".prime" / "lab" / "context.json").read_text()) + assert result.exit_code == 0 + assert captured == { + "name": "Alphabet Sort", + "slug": None, + "description": None, + "team_id": None, + } + assert context["project_id"] == "project-123" + assert context["project_slug"] == "alphabet-sort" + assert "Created Lab project Alphabet Sort (alphabet-sort)" in _render_emitted(emitted) + + +def test_lab_setup_uses_existing_project_option( + tmp_path: Path, + monkeypatch: Any, +) -> None: + monkeypatch.setenv("HOME", str(tmp_path / "home")) + monkeypatch.setenv("PRIME_API_KEY", "test-key") + monkeypatch.setattr(AGENT_WHICH, lambda _command: "/bin/tool") + monkeypatch.setattr("prime_cli.lab_setup.APIClient", lambda: object()) + captured: dict[str, Any] = {} + + class DummyProjectsClient: + def __init__(self, _api_client: object) -> None: + pass + + def get(self, project_ref: str, team_id: str | None = None) -> Project: + captured["project_ref"] = project_ref + captured["team_id"] = team_id + return _project(name="Existing Project", slug="existing-project") + + monkeypatch.setattr("prime_cli.lab_setup.ProjectsClient", DummyProjectsClient) + + result = run_lab_setup_service( + LabSetupOptions( + skip_install=True, + skip_agents_md=True, + agents=("codex",), + project_ref="existing-project", + ), + workspace=tmp_path, + emit=lambda _item: None, + ) + + context = json.loads((tmp_path / ".prime" / "lab" / "context.json").read_text()) + assert result.exit_code == 0 + assert captured == {"project_ref": "existing-project", "team_id": None} + assert context["project_name"] == "Existing Project" + + +def test_lab_setup_project_option_overrides_active_project( + tmp_path: Path, + monkeypatch: Any, +) -> None: + monkeypatch.setenv("HOME", str(tmp_path / "home")) + monkeypatch.setenv("PRIME_API_KEY", "test-key") + monkeypatch.setattr(AGENT_WHICH, lambda _command: "/bin/tool") + monkeypatch.setattr("prime_cli.lab_setup.APIClient", lambda: object()) + context_dir = tmp_path / ".prime" / "lab" + context_dir.mkdir(parents=True) + (context_dir / "context.json").write_text( + json.dumps( + { + "project_id": "old-project", + "project_slug": "old-project", + "project_name": "Old Project", + "team_id": None, + "base_url": "https://api.primeintellect.ai", + } + ), + encoding="utf-8", + ) + + class DummyProjectsClient: + def __init__(self, _api_client: object) -> None: + pass + + def get(self, project_ref: str, team_id: str | None = None) -> Project: + assert project_ref == "new-project" + assert team_id is None + return _project( + id="new-project", + name="New Project", + slug="new-project", + ) + + monkeypatch.setattr("prime_cli.lab_setup.ProjectsClient", DummyProjectsClient) + + result = run_lab_setup_service( + LabSetupOptions( + skip_install=True, + skip_agents_md=True, + agents=("codex",), + project_ref="new-project", + ), + workspace=tmp_path, + emit=lambda _item: None, + ) + + context = json.loads((context_dir / "context.json").read_text()) + assert result.exit_code == 0 + assert context["project_id"] == "new-project" + assert context["project_name"] == "New Project" + + +def test_lab_setup_project_name_overrides_active_project( + tmp_path: Path, + monkeypatch: Any, +) -> None: + monkeypatch.setenv("HOME", str(tmp_path / "home")) + monkeypatch.setenv("PRIME_API_KEY", "test-key") + monkeypatch.setattr(AGENT_WHICH, lambda _command: "/bin/tool") + monkeypatch.setattr("prime_cli.lab_setup.APIClient", lambda: object()) + context_dir = tmp_path / ".prime" / "lab" + context_dir.mkdir(parents=True) + (context_dir / "context.json").write_text( + json.dumps( + { + "project_id": "old-project", + "project_slug": "old-project", + "project_name": "Old Project", + "team_id": None, + "base_url": "https://api.primeintellect.ai", + } + ), + encoding="utf-8", + ) + captured: dict[str, Any] = {} + + class DummyProjectsClient: + def __init__(self, _api_client: object) -> None: + pass + + def create( + self, + name: str, + slug: str | None = None, + description: str | None = None, + team_id: str | None = None, + ) -> Project: + captured.update( + { + "name": name, + "slug": slug, + "description": description, + "team_id": team_id, + } + ) + return _project( + id="new-project", + name=name, + slug="new-project", + ) + + monkeypatch.setattr("prime_cli.lab_setup.ProjectsClient", DummyProjectsClient) + + result = run_lab_setup_service( + LabSetupOptions( + skip_install=True, + skip_agents_md=True, + agents=("codex",), + project_name="New Project", + ), + workspace=tmp_path, + emit=lambda _item: None, + ) + + context = json.loads((context_dir / "context.json").read_text()) + assert result.exit_code == 0 + assert captured == { + "name": "New Project", + "slug": None, + "description": None, + "team_id": None, + } + assert context["project_id"] == "new-project" + assert context["project_name"] == "New Project" + + +def test_lab_setup_does_not_activate_mismatched_team_project( + tmp_path: Path, + monkeypatch: Any, +) -> None: + monkeypatch.setenv("HOME", str(tmp_path / "home")) + monkeypatch.setenv("PRIME_API_KEY", "test-key") + monkeypatch.setattr(AGENT_WHICH, lambda _command: "/bin/tool") + monkeypatch.setattr("prime_cli.lab_setup.APIClient", lambda: object()) + + class DummyProjectsClient: + def __init__(self, _api_client: object) -> None: + pass + + def get(self, project_ref: str, team_id: str | None = None) -> Project: + assert project_ref == "team-project" + assert team_id is None + return _project( + name="Team Project", + slug="team-project", + team_id="team-123", + ) + + monkeypatch.setattr("prime_cli.lab_setup.ProjectsClient", DummyProjectsClient) + emitted: list[Any] = [] + + result = run_lab_setup_service( + LabSetupOptions( + skip_install=True, + skip_agents_md=True, + agents=("codex",), + project_ref="team-project", + ), + workspace=tmp_path, + emit=emitted.append, + ) + + output = _render_emitted(emitted) + assert result.exit_code == 0 + assert not (tmp_path / ".prime" / "lab" / "context.json").exists() + assert "Cannot set an active project for team team-123" in output + + +def test_lab_setup_no_project_skips_project_api( + tmp_path: Path, + monkeypatch: Any, +) -> None: + monkeypatch.setenv("HOME", str(tmp_path / "home")) + monkeypatch.setenv("PRIME_API_KEY", "test-key") + monkeypatch.setattr(AGENT_WHICH, lambda _command: "/bin/tool") + monkeypatch.setattr( + "prime_cli.lab_setup.ProjectsClient", + lambda *_args, **_kwargs: (_ for _ in ()).throw( + AssertionError("should not construct ProjectsClient") + ), + ) + emitted: list[Any] = [] + + result = run_lab_setup_service( + LabSetupOptions( + skip_install=True, + skip_agents_md=True, + agents=("codex",), + no_project=True, + ), + workspace=tmp_path, + emit=emitted.append, + ) + + assert result.exit_code == 0 + assert not (tmp_path / ".prime" / "lab" / "context.json").exists() + assert "Skipped Lab project setup (--no-project)" in _render_emitted(emitted) + + +def test_lab_setup_without_api_key_continues_without_project( + tmp_path: Path, + monkeypatch: Any, +) -> None: + monkeypatch.setenv("HOME", str(tmp_path / "home")) + monkeypatch.setattr(AGENT_WHICH, lambda _command: "/bin/tool") + emitted: list[Any] = [] + + result = run_lab_setup_service( + LabSetupOptions(skip_install=True, skip_agents_md=True, agents=("codex",)), + workspace=tmp_path, + emit=emitted.append, + ) + + assert result.exit_code == 0 + assert not (tmp_path / ".prime" / "lab" / "context.json").exists() + assert "Skipped Lab project setup because no API key is configured" in _render_emitted(emitted) + + def test_lab_setup_service_emits_post_setup_call_to_action( tmp_path: Path, monkeypatch: Any, diff --git a/packages/prime/tests/test_lab_view.py b/packages/prime/tests/test_lab_view.py index 59417e8f..99cddead 100644 --- a/packages/prime/tests/test_lab_view.py +++ b/packages/prime/tests/test_lab_view.py @@ -2178,6 +2178,8 @@ def test_prime_lab_doctor_service_checks_and_fixes_workspace(tmp_path: Path) -> assert "/AGENTS.md" in gitignore_lines assert "/CLAUDE.md" in gitignore_lines assert "/CLAUDE.local.md" in gitignore_lines + assert "/.prime/" in gitignore_lines + assert "/.agents/skills/" in gitignore_lines assert "/outputs/" in gitignore_lines assert "/prime-rl/" in gitignore_lines assert "/environments/AGENTS.md" in gitignore_lines diff --git a/packages/prime/tests/test_projects_cli.py b/packages/prime/tests/test_projects_cli.py new file mode 100644 index 00000000..96237157 --- /dev/null +++ b/packages/prime/tests/test_projects_cli.py @@ -0,0 +1,731 @@ +import json +from typing import Any, Optional + +import pytest +from prime_cli.api.projects import Project +from prime_cli.core import APIError +from prime_cli.main import app +from prime_cli.utils.projects import ( + get_active_project_id, + resolve_project_id, + write_project_context, +) +from typer.testing import CliRunner + +runner = CliRunner() + +TEST_ENV = { + "COLUMNS": "200", + "LINES": "50", + "PRIME_DISABLE_VERSION_CHECK": "1", + "PRIME_API_KEY": "test-key", + "PRIME_TEAM_ID": None, +} + + +def _project(team_id: Optional[str] = None) -> Project: + return Project.model_validate( + { + "id": "cmproject0000000000000001", + "name": "Battleship Baseline", + "slug": "battleship-baseline", + "status": "ACTIVE", + "userId": "cmuser000000000000000001", + "teamId": team_id, + "createdAt": "2026-05-20T12:00:00Z", + "updatedAt": "2026-05-20T12:00:00Z", + } + ) + + +def test_project_create_sets_active_context(monkeypatch, tmp_path) -> None: + monkeypatch.chdir(tmp_path) + monkeypatch.setenv("HOME", str(tmp_path)) + monkeypatch.setattr("prime_cli.main.check_for_update", lambda: (False, None)) + monkeypatch.setattr("prime_cli.commands.projects.APIClient", lambda: object()) + + class DummyProjectsClient: + def __init__(self, _api_client: Any) -> None: + pass + + def create( + self, + name: str, + slug: Optional[str] = None, + description: Optional[str] = None, + team_id: Optional[str] = None, + ) -> Project: + assert name == "Battleship Baseline" + assert slug is None + assert description is None + assert team_id is None + return _project() + + monkeypatch.setattr("prime_cli.commands.projects.ProjectsClient", DummyProjectsClient) + + result = runner.invoke(app, ["project", "create", "Battleship Baseline"], env=TEST_ENV) + + assert result.exit_code == 0, result.output + context_path = tmp_path / ".prime" / "lab" / "context.json" + assert context_path.exists() + context_text = context_path.read_text() + assert "cmproject0000000000000001" in context_text + assert context_text.endswith("\n") + + +def test_project_create_team_project_requires_matching_active_team_to_use( + monkeypatch, + tmp_path, +) -> None: + monkeypatch.chdir(tmp_path) + monkeypatch.setenv("HOME", str(tmp_path)) + monkeypatch.setattr("prime_cli.main.check_for_update", lambda: (False, None)) + + result = runner.invoke( + app, + ["project", "create", "Team Project", "--team-id", "team-123"], + env=TEST_ENV, + ) + + assert result.exit_code == 1 + assert "Cannot create and set an active project for team team-123" in result.output + assert not (tmp_path / ".prime" / "lab" / "context.json").exists() + + +def test_project_create_team_project_no_use_does_not_set_active_context( + monkeypatch, + tmp_path, +) -> None: + monkeypatch.chdir(tmp_path) + monkeypatch.setenv("HOME", str(tmp_path)) + monkeypatch.setattr("prime_cli.main.check_for_update", lambda: (False, None)) + monkeypatch.setattr("prime_cli.commands.projects.APIClient", lambda: object()) + + class DummyProjectsClient: + def __init__(self, _api_client: Any) -> None: + pass + + def create( + self, + name: str, + slug: Optional[str] = None, + description: Optional[str] = None, + team_id: Optional[str] = None, + ) -> Project: + assert name == "Team Project" + assert slug is None + assert description is None + assert team_id == "team-123" + return _project(team_id="team-123") + + monkeypatch.setattr("prime_cli.commands.projects.ProjectsClient", DummyProjectsClient) + + result = runner.invoke( + app, + ["project", "create", "Team Project", "--team-id", "team-123", "--no-use"], + env=TEST_ENV, + ) + + assert result.exit_code == 0, result.output + assert "Project created" in result.output + assert not (tmp_path / ".prime" / "lab" / "context.json").exists() + + +def test_project_create_team_project_sets_active_when_cli_team_matches( + monkeypatch, + tmp_path, +) -> None: + monkeypatch.chdir(tmp_path) + monkeypatch.setenv("HOME", str(tmp_path)) + monkeypatch.setattr("prime_cli.main.check_for_update", lambda: (False, None)) + monkeypatch.setattr("prime_cli.commands.projects.APIClient", lambda: object()) + + class DummyProjectsClient: + def __init__(self, _api_client: Any) -> None: + pass + + def create( + self, + name: str, + slug: Optional[str] = None, + description: Optional[str] = None, + team_id: Optional[str] = None, + ) -> Project: + assert name == "Battleship Baseline" + assert slug is None + assert description is None + assert team_id == "team-123" + return _project(team_id="team-123") + + monkeypatch.setattr("prime_cli.commands.projects.ProjectsClient", DummyProjectsClient) + + result = runner.invoke( + app, + ["project", "create", "Battleship Baseline"], + env={**TEST_ENV, "PRIME_TEAM_ID": "team-123"}, + ) + + assert result.exit_code == 0, result.output + context_path = tmp_path / ".prime" / "lab" / "context.json" + assert context_path.exists() + assert '"team_id": "team-123"' in context_path.read_text() + + +def test_project_current_reads_active_context(monkeypatch, tmp_path) -> None: + monkeypatch.chdir(tmp_path) + monkeypatch.setenv("HOME", str(tmp_path)) + monkeypatch.setattr("prime_cli.main.check_for_update", lambda: (False, None)) + monkeypatch.setattr("prime_cli.commands.projects.APIClient", lambda: object()) + + context_dir = tmp_path / ".prime" / "lab" + context_dir.mkdir(parents=True) + (context_dir / "context.json").write_text( + '{"project_id":"cmproject0000000000000001","team_id":null,' + '"base_url":"https://api.primeintellect.ai"}' + ) + + class DummyProjectsClient: + def __init__(self, _api_client: Any) -> None: + pass + + def get(self, project_ref: str, team_id: Optional[str] = None) -> Project: + assert project_ref == "cmproject0000000000000001" + assert team_id is None + return _project() + + monkeypatch.setattr("prime_cli.commands.projects.ProjectsClient", DummyProjectsClient) + + result = runner.invoke(app, ["project", "current"], env=TEST_ENV) + + assert result.exit_code == 0, result.output + assert "Battleship Baseline" in result.output + assert "battleship-baseline" in result.output + + +def test_project_current_json_api_error_keeps_project_shape_stable( + monkeypatch, + tmp_path, +) -> None: + monkeypatch.chdir(tmp_path) + monkeypatch.setenv("HOME", str(tmp_path)) + monkeypatch.setattr("prime_cli.main.check_for_update", lambda: (False, None)) + monkeypatch.setattr("prime_cli.commands.projects.APIClient", lambda: object()) + + context_dir = tmp_path / ".prime" / "lab" + context_dir.mkdir(parents=True) + (context_dir / "context.json").write_text( + '{"project_id":"cmproject0000000000000001","project_slug":"battleship-baseline",' + '"team_id":null,"base_url":"https://api.primeintellect.ai"}' + ) + + class DummyProjectsClient: + def __init__(self, _api_client: Any) -> None: + pass + + def get(self, project_ref: str, team_id: Optional[str] = None) -> Project: + assert project_ref == "cmproject0000000000000001" + assert team_id is None + raise APIError("offline") + + monkeypatch.setattr("prime_cli.commands.projects.ProjectsClient", DummyProjectsClient) + + result = runner.invoke(app, ["project", "current", "--output", "json"], env=TEST_ENV) + + assert result.exit_code == 0, result.output + payload = json.loads(result.output) + assert payload["project"] is None + assert payload["context"]["project_id"] == "cmproject0000000000000001" + + +def test_project_show_uses_active_context(monkeypatch, tmp_path) -> None: + monkeypatch.chdir(tmp_path) + monkeypatch.setenv("HOME", str(tmp_path)) + monkeypatch.setattr("prime_cli.main.check_for_update", lambda: (False, None)) + monkeypatch.setattr("prime_cli.commands.projects.APIClient", lambda: object()) + + context_dir = tmp_path / ".prime" / "lab" + context_dir.mkdir(parents=True) + (context_dir / "context.json").write_text( + '{"project_id":"cmproject0000000000000001","team_id":null,' + '"base_url":"https://api.primeintellect.ai"}' + ) + + class DummyProjectsClient: + def __init__(self, _api_client: Any) -> None: + pass + + def get(self, project_ref: str, team_id: Optional[str] = None) -> Project: + assert project_ref == "cmproject0000000000000001" + assert team_id is None + return _project() + + monkeypatch.setattr("prime_cli.commands.projects.ProjectsClient", DummyProjectsClient) + + result = runner.invoke(app, ["project", "show"], env=TEST_ENV) + + assert result.exit_code == 0, result.output + assert "Battleship Baseline" in result.output + assert "Not set" in result.output + + +def test_project_use_team_project_requires_matching_active_team( + monkeypatch, + tmp_path, +) -> None: + monkeypatch.chdir(tmp_path) + monkeypatch.setenv("HOME", str(tmp_path)) + monkeypatch.setattr("prime_cli.main.check_for_update", lambda: (False, None)) + monkeypatch.setattr("prime_cli.commands.projects.APIClient", lambda: object()) + + class DummyProjectsClient: + def __init__(self, _api_client: Any) -> None: + pass + + def get(self, project_ref: str, team_id: Optional[str] = None) -> Project: + assert project_ref == "battleship-baseline" + assert team_id == "team-123" + return _project(team_id="team-123") + + monkeypatch.setattr("prime_cli.commands.projects.ProjectsClient", DummyProjectsClient) + + result = runner.invoke( + app, + ["project", "use", "battleship-baseline", "--team-id", "team-123"], + env=TEST_ENV, + ) + + assert result.exit_code == 1 + assert "Cannot set an active project for team team-123" in result.output + assert not (tmp_path / ".prime" / "lab" / "context.json").exists() + + +def test_project_use_team_project_sets_active_when_cli_team_matches( + monkeypatch, + tmp_path, +) -> None: + monkeypatch.chdir(tmp_path) + monkeypatch.setenv("HOME", str(tmp_path)) + monkeypatch.setattr("prime_cli.main.check_for_update", lambda: (False, None)) + monkeypatch.setattr("prime_cli.commands.projects.APIClient", lambda: object()) + + class DummyProjectsClient: + def __init__(self, _api_client: Any) -> None: + pass + + def get(self, project_ref: str, team_id: Optional[str] = None) -> Project: + assert project_ref == "battleship-baseline" + assert team_id == "team-123" + return _project(team_id="team-123") + + monkeypatch.setattr("prime_cli.commands.projects.ProjectsClient", DummyProjectsClient) + + result = runner.invoke( + app, + ["project", "use", "battleship-baseline"], + env={**TEST_ENV, "PRIME_TEAM_ID": "team-123"}, + ) + + assert result.exit_code == 0, result.output + assert "Active project updated" in result.output + context_path = tmp_path / ".prime" / "lab" / "context.json" + assert context_path.exists() + assert '"team_id": "team-123"' in context_path.read_text() + + +def test_project_update_description(monkeypatch, tmp_path) -> None: + monkeypatch.chdir(tmp_path) + monkeypatch.setenv("HOME", str(tmp_path)) + monkeypatch.setattr("prime_cli.main.check_for_update", lambda: (False, None)) + monkeypatch.setattr("prime_cli.commands.projects.APIClient", lambda: object()) + + class DummyProjectsClient: + def __init__(self, _api_client: Any) -> None: + pass + + def update( + self, + project_ref: str, + name: Optional[str] = None, + slug: Optional[str] = None, + description: Optional[str] = None, + team_id: Optional[str] = None, + ) -> Project: + assert project_ref == "battleship-baseline" + assert name is None + assert slug is None + assert description == "Baseline and follow-up Battleship runs" + assert team_id is None + return _project().model_copy( + update={"description": "Baseline and follow-up Battleship runs"} + ) + + monkeypatch.setattr("prime_cli.commands.projects.ProjectsClient", DummyProjectsClient) + + result = runner.invoke( + app, + [ + "project", + "update", + "battleship-baseline", + "--description", + "Baseline and follow-up Battleship runs", + ], + env=TEST_ENV, + ) + + assert result.exit_code == 0, result.output + assert "Project updated" in result.output + assert "Baseline and follow-up Battleship runs" in result.output + + +def test_project_update_requires_field(monkeypatch, tmp_path) -> None: + monkeypatch.chdir(tmp_path) + monkeypatch.setenv("HOME", str(tmp_path)) + monkeypatch.setattr("prime_cli.main.check_for_update", lambda: (False, None)) + + result = runner.invoke(app, ["project", "update", "battleship-baseline"], env=TEST_ENV) + + assert result.exit_code == 1 + assert "Provide --name, --slug, --description, or --clear-description" in result.output + + +def test_project_assign_run_uses_active_project(monkeypatch, tmp_path) -> None: + monkeypatch.chdir(tmp_path) + monkeypatch.setenv("HOME", str(tmp_path)) + monkeypatch.setattr("prime_cli.main.check_for_update", lambda: (False, None)) + monkeypatch.setattr("prime_cli.commands.projects.APIClient", lambda: object()) + + context_dir = tmp_path / ".prime" / "lab" + context_dir.mkdir(parents=True) + (context_dir / "context.json").write_text( + '{"project_id":"cmproject0000000000000001","team_id":null,' + '"base_url":"https://api.primeintellect.ai"}' + ) + + class DummyProjectsClient: + def __init__(self, _api_client: Any) -> None: + pass + + def get(self, project_ref: str, team_id: Optional[str] = None) -> Project: + assert project_ref == "cmproject0000000000000001" + assert team_id is None + return _project() + + class DummyRLClient: + def __init__(self, _api_client: Any) -> None: + pass + + def update_run_project( + self, + run_id: str, + project_id: Optional[str], + *, + operation: str = "set", + move_adapters: bool = True, + ) -> tuple[object, int]: + assert run_id == "run-123" + assert project_id == "cmproject0000000000000001" + assert operation == "add" + assert move_adapters is True + return object(), 2 + + monkeypatch.setattr("prime_cli.commands.projects.ProjectsClient", DummyProjectsClient) + monkeypatch.setattr("prime_cli.commands.projects.RLClient", DummyRLClient) + + result = runner.invoke(app, ["project", "assign", "run", "run-123"], env=TEST_ENV) + + assert result.exit_code == 0, result.output + assert "Project assigned" in result.output + assert "Adapters Updated" in result.output + + +def test_project_remove_run_forwards_targeted_payload(monkeypatch, tmp_path) -> None: + monkeypatch.chdir(tmp_path) + monkeypatch.setenv("HOME", str(tmp_path)) + monkeypatch.setattr("prime_cli.main.check_for_update", lambda: (False, None)) + monkeypatch.setattr("prime_cli.commands.projects.APIClient", lambda: object()) + + captured = {} + + class DummyProjectsClient: + def __init__(self, _api_client: Any) -> None: + pass + + def get(self, project_ref: str, team_id: Optional[str] = None) -> Project: + assert project_ref == "battleship-baseline" + assert team_id is None + return _project() + + class DummyRLClient: + def __init__(self, _api_client: Any) -> None: + pass + + def update_run_project( + self, + run_id: str, + project_id: Optional[str], + *, + operation: str = "set", + move_adapters: bool = True, + ) -> tuple[object, int]: + captured.update( + { + "run_id": run_id, + "project_id": project_id, + "operation": operation, + "move_adapters": move_adapters, + } + ) + return object(), 0 + + monkeypatch.setattr("prime_cli.commands.projects.ProjectsClient", DummyProjectsClient) + monkeypatch.setattr("prime_cli.commands.projects.RLClient", DummyRLClient) + + result = runner.invoke( + app, + [ + "project", + "remove", + "run", + "run-123", + "battleship-baseline", + "--no-move-adapters", + "--output", + "json", + ], + env=TEST_ENV, + ) + + assert result.exit_code == 0, result.output + assert captured == { + "run_id": "run-123", + "project_id": "cmproject0000000000000001", + "operation": "remove", + "move_adapters": False, + } + + +def test_project_assign_adapter_forwards_project_payload(monkeypatch, tmp_path) -> None: + monkeypatch.chdir(tmp_path) + monkeypatch.setenv("HOME", str(tmp_path)) + monkeypatch.setattr("prime_cli.main.check_for_update", lambda: (False, None)) + monkeypatch.setattr("prime_cli.commands.projects.APIClient", lambda: object()) + + captured = {} + + class DummyProjectsClient: + def __init__(self, _api_client: Any) -> None: + pass + + def get(self, project_ref: str, team_id: Optional[str] = None) -> Project: + assert project_ref == "battleship-baseline" + assert team_id is None + return _project() + + class DummyDeploymentsClient: + def __init__(self, _api_client: Any) -> None: + pass + + def update_adapter_project( + self, + adapter_id: str, + project_id: Optional[str], + *, + operation: str = "set", + ) -> object: + captured.update( + { + "adapter_id": adapter_id, + "project_id": project_id, + "operation": operation, + } + ) + return object() + + monkeypatch.setattr("prime_cli.commands.projects.ProjectsClient", DummyProjectsClient) + monkeypatch.setattr("prime_cli.commands.projects.DeploymentsClient", DummyDeploymentsClient) + + result = runner.invoke( + app, + [ + "project", + "assign", + "adapter", + "adapter-123", + "battleship-baseline", + "--output", + "json", + ], + env=TEST_ENV, + ) + + assert result.exit_code == 0, result.output + assert captured == { + "adapter_id": "adapter-123", + "project_id": "cmproject0000000000000001", + "operation": "add", + } + + +def test_project_remove_adapter_without_project_clears_memberships( + monkeypatch, + tmp_path, +) -> None: + monkeypatch.chdir(tmp_path) + monkeypatch.setenv("HOME", str(tmp_path)) + monkeypatch.setattr("prime_cli.main.check_for_update", lambda: (False, None)) + monkeypatch.setattr("prime_cli.commands.projects.APIClient", lambda: object()) + + captured = {} + + class DummyDeploymentsClient: + def __init__(self, _api_client: Any) -> None: + pass + + def update_adapter_project( + self, + adapter_id: str, + project_id: Optional[str], + *, + operation: str = "set", + ) -> object: + captured.update( + { + "adapter_id": adapter_id, + "project_id": project_id, + "operation": operation, + } + ) + return object() + + monkeypatch.setattr("prime_cli.commands.projects.DeploymentsClient", DummyDeploymentsClient) + + result = runner.invoke( + app, + ["project", "remove", "adapter", "adapter-123", "--output", "json"], + env=TEST_ENV, + ) + + assert result.exit_code == 0, result.output + assert captured == { + "adapter_id": "adapter-123", + "project_id": None, + "operation": "clear", + } + + +def test_project_remove_eval_clears_project(monkeypatch, tmp_path) -> None: + monkeypatch.chdir(tmp_path) + monkeypatch.setenv("HOME", str(tmp_path)) + monkeypatch.setattr("prime_cli.main.check_for_update", lambda: (False, None)) + monkeypatch.setattr("prime_cli.commands.projects.APIClient", lambda: object()) + + class DummyEvalsClient: + def __init__(self, _api_client: Any) -> None: + pass + + def update_evaluation( + self, + evaluation_id: str, + *, + clear_project: bool = False, + ) -> dict: + assert evaluation_id == "eval-123" + assert clear_project is True + return {"evaluation_id": evaluation_id} + + monkeypatch.setattr("prime_cli.commands.projects.EvalsClient", DummyEvalsClient) + + result = runner.invoke(app, ["project", "remove", "eval", "eval-123"], env=TEST_ENV) + + assert result.exit_code == 0, result.output + assert "Project removed" in result.output + + +def test_project_remove_eval_rejects_targeted_project(monkeypatch, tmp_path) -> None: + monkeypatch.chdir(tmp_path) + monkeypatch.setenv("HOME", str(tmp_path)) + monkeypatch.setattr("prime_cli.main.check_for_update", lambda: (False, None)) + monkeypatch.setattr("prime_cli.commands.projects.APIClient", lambda: object()) + + result = runner.invoke( + app, + ["project", "remove", "eval", "eval-123", "battleship-baseline"], + env=TEST_ENV, + ) + + assert result.exit_code == 1 + assert "Targeted removal from one project is not supported for evaluations" in result.output + + +def test_active_project_context_is_discovered_from_parent_workspace( + monkeypatch, + tmp_path, +) -> None: + monkeypatch.delenv("PRIME_API_BASE_URL", raising=False) + monkeypatch.delenv("PRIME_BASE_URL", raising=False) + monkeypatch.delenv("PRIME_PROJECT_ID", raising=False) + monkeypatch.delenv("PRIME_TEAM_ID", raising=False) + + context_dir = tmp_path / ".prime" / "lab" + context_dir.mkdir(parents=True) + (tmp_path / ".prime" / "lab.json").write_text("{}") + (context_dir / "context.json").write_text( + '{"project_id":"cmproject0000000000000001","team_id":null,' + '"base_url":"https://api.primeintellect.ai"}' + ) + + nested = tmp_path / "outputs" / "evals" / "gsm8k" / "run-123" + nested.mkdir(parents=True) + monkeypatch.chdir(nested) + monkeypatch.setenv("HOME", str(tmp_path)) + + assert get_active_project_id() == "cmproject0000000000000001" + + +def test_active_project_context_does_not_cross_lab_workspace_boundary( + monkeypatch, + tmp_path, +) -> None: + monkeypatch.delenv("PRIME_API_BASE_URL", raising=False) + monkeypatch.delenv("PRIME_BASE_URL", raising=False) + monkeypatch.delenv("PRIME_PROJECT_ID", raising=False) + monkeypatch.delenv("PRIME_TEAM_ID", raising=False) + + parent_context_dir = tmp_path / ".prime" / "lab" + parent_context_dir.mkdir(parents=True) + (tmp_path / ".prime" / "lab.json").write_text("{}") + (parent_context_dir / "context.json").write_text( + '{"project_id":"parent-project","team_id":null,"base_url":"https://api.primeintellect.ai"}' + ) + + child_workspace = tmp_path / "child-workspace" + (child_workspace / ".prime").mkdir(parents=True) + (child_workspace / ".prime" / "lab.json").write_text("{}") + nested = child_workspace / "outputs" / "evals" / "gsm8k" / "run-123" + nested.mkdir(parents=True) + + monkeypatch.chdir(nested) + monkeypatch.setenv("HOME", str(tmp_path)) + + assert get_active_project_id() is None + + +def test_write_project_context_uses_parent_lab_workspace(monkeypatch, tmp_path) -> None: + (tmp_path / ".prime" / "lab").mkdir(parents=True) + (tmp_path / ".prime" / "lab.json").write_text("{}") + nested = tmp_path / "outputs" / "evals" / "gsm8k" / "run-123" + nested.mkdir(parents=True) + + monkeypatch.chdir(nested) + monkeypatch.setenv("HOME", str(tmp_path)) + + write_project_context(_project()) + + assert (tmp_path / ".prime" / "lab" / "context.json").exists() + assert not (nested / ".prime" / "lab" / "context.json").exists() + + +def test_project_and_no_project_are_mutually_exclusive() -> None: + with pytest.raises(APIError, match="Cannot use --project and --no-project together"): + resolve_project_id("cmproject0000000000000001", no_project=True) diff --git a/packages/prime/tests/test_rl_api.py b/packages/prime/tests/test_rl_api.py index 6e62c64f..f743d271 100644 --- a/packages/prime/tests/test_rl_api.py +++ b/packages/prime/tests/test_rl_api.py @@ -9,6 +9,7 @@ class FakeAPIClient: def __init__(self) -> None: self.requests: list[tuple[str, dict[str, Any] | None]] = [] self.posts: list[tuple[str, dict[str, Any] | None]] = [] + self.patch_requests: list[tuple[str, str, dict[str, Any] | None]] = [] def get(self, endpoint: str, params: dict[str, Any] | None = None) -> dict[str, Any]: self.requests.append((endpoint, params)) @@ -40,11 +41,36 @@ def post(self, endpoint: str, json: dict[str, Any] | None = None) -> dict[str, A "batchSize": json["batch_size"] if json else 128, "baseModel": json["model"]["name"] if json else "model", "maxInflightRollouts": json.get("max_inflight_rollouts") if json else None, + "projectId": json.get("project_id") if json else None, "createdAt": "2026-05-17T00:00:00Z", "updatedAt": "2026-05-17T00:00:00Z", } } + def request( + self, + method: str, + endpoint: str, + json: dict[str, Any] | None = None, + ) -> dict[str, Any]: + self.patch_requests.append((method, endpoint, json)) + return { + "run": { + "id": "run-1", + "userId": "user-1", + "status": "QUEUED", + "rolloutsPerExample": 8, + "seqLen": 2048, + "maxSteps": 100, + "batchSize": 128, + "baseModel": "model", + "projectId": json.get("projectId") if json else None, + "createdAt": "2026-05-17T00:00:00Z", + "updatedAt": "2026-05-17T00:00:00Z", + }, + "adaptersUpdated": 3, + } + def test_get_distributions_preserves_chart_histogram_data() -> None: api_client = FakeAPIClient() @@ -134,3 +160,45 @@ def test_create_run_omits_default_rl_loss() -> None: assert api_client.posts[0][0] == "/rft/runs" assert "loss" not in api_client.posts[0][1] assert "teacher" not in api_client.posts[0][1] + + +def test_create_run_sends_project_id() -> None: + api_client = FakeAPIClient() + client = RLClient(api_client) # type: ignore[arg-type] + + run = client.create_run( + model_name="Qwen/Qwen3.5-0.8B", + environments=[{"id": "reverse-text"}], + project_id="project-123", + ) + + assert api_client.posts[0][0] == "/rft/runs" + assert api_client.posts[0][1]["project_id"] == "project-123" + assert "projectId" not in api_client.posts[0][1] + assert run.project_id == "project-123" + + +def test_update_run_project_sends_backend_payload_shape() -> None: + api_client = FakeAPIClient() + client = RLClient(api_client) # type: ignore[arg-type] + + run, adapters_updated = client.update_run_project( + "run-123", + "project-123", + operation="remove", + move_adapters=False, + ) + + assert api_client.patch_requests == [ + ( + "PATCH", + "/rft/runs/run-123/project", + { + "projectId": "project-123", + "operation": "remove", + "moveAdapters": False, + }, + ) + ] + assert run.project_id == "project-123" + assert adapters_updated == 3