diff --git a/trapdata/antenna/benchmark.py b/trapdata/antenna/benchmark.py
index 4065fec..4d91109 100644
--- a/trapdata/antenna/benchmark.py
+++ b/trapdata/antenna/benchmark.py
@@ -49,7 +49,7 @@ def create_empty_result(reply_subject: str, image_id: str) -> AntennaTaskResult:
 def run_benchmark(
     job_id: int,
     base_url: str,
-    auth_token: str,
+    api_key: str,
     num_workers: int,
     batch_size: int,
     gpu_batch_size: int,
@@ -60,7 +60,7 @@ def run_benchmark(
     Args:
         job_id: Job ID to process
         base_url: Antenna API base URL
-        auth_token: API authentication token
+        api_key: API key for authentication
         num_workers: Number of DataLoader workers
         batch_size: Batch size for API requests
         gpu_batch_size: GPU batch size for
@@ -68,7 +68,7 @@ def run_benchmark(
     # Create settings object
     settings = Settings()
     settings.antenna_api_base_url = base_url
-    settings.antenna_api_auth_token = auth_token
+    settings.antenna_api_key = api_key
     settings.antenna_api_batch_size = batch_size
     settings.localization_batch_size = gpu_batch_size
     settings.num_workers = num_workers
@@ -134,7 +134,7 @@ def run_benchmark(
             # Send acknowledgments asynchronously
             result_poster.post_async(
                 base_url=base_url,
-                auth_token=auth_token,
+                api_key=api_key,
                 job_id=job_id,
                 results=ack_results,
             )
@@ -156,7 +156,7 @@ def run_benchmark(
     if error_results and send_acks:
         result_poster.post_async(
             base_url=base_url,
-            auth_token=auth_token,
+            api_key=api_key,
             job_id=job_id,
             results=error_results,
         )
@@ -278,16 +278,16 @@ def main() -> int:
     args = parser.parse_args()
 
     # Get auth token from environment
-    auth_token = os.getenv("AMI_ANTENNA_API_AUTH_TOKEN", "")
-    if not auth_token:
-        print("ERROR: AMI_ANTENNA_API_AUTH_TOKEN environment variable not set")
+    api_key = os.getenv("AMI_ANTENNA_API_KEY", "")
+    if not api_key:
+        print("ERROR: AMI_ANTENNA_API_KEY environment variable not set")
         return 1
 
     # Run the benchmark
     run_benchmark(
         job_id=args.job_id,
         base_url=args.base_url,
-        auth_token=auth_token,
+        api_key=api_key,
         num_workers=args.num_workers,
         batch_size=args.batch_size,
         gpu_batch_size=args.gpu_batch_size,
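
Reviewer note (not part of the patch): a minimal sketch of the renamed entry point, assuming only the parameters visible in the hunks above; the job id and key value are hypothetical, and any run_benchmark parameters not shown in the diff are omitted here.

    import os

    from trapdata.antenna.benchmark import run_benchmark

    run_benchmark(
        job_id=123,  # hypothetical job id
        base_url="http://localhost:8000/api/v2",
        api_key=os.environ["AMI_ANTENNA_API_KEY"],  # was AMI_ANTENNA_API_AUTH_TOKEN
        num_workers=4,
        batch_size=24,
        gpu_batch_size=24,
    )
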
""" - with get_http_session(auth_token) as session: + with get_http_session(api_key) as session: try: if not pipeline_slugs: return [] @@ -73,7 +58,7 @@ def get_jobs( def post_batch_results( base_url: str, - auth_token: str, + api_key: str, job_id: int, results: list[AntennaTaskResult], ) -> bool: @@ -82,7 +67,7 @@ def post_batch_results( Args: base_url: Antenna API base URL (e.g., "http://localhost:8000/api/v2") - auth_token: API authentication token + api_key: API key for authentication job_id: Job ID results: List of AntennaTaskResult objects @@ -92,7 +77,7 @@ def post_batch_results( url = f"{base_url.rstrip('/')}/jobs/{job_id}/result/" payload = AntennaTaskResults(results=results) - with get_http_session(auth_token) as session: + with get_http_session(api_key) as session: try: response = session.post( url, json=payload.model_dump(mode="json"), timeout=60 @@ -108,18 +93,18 @@ def post_batch_results( return False -def get_user_projects(base_url: str, auth_token: str) -> list[dict]: +def get_user_projects(base_url: str, api_key: str) -> list[dict]: """ Fetch all projects the user has access to. Args: base_url: Base URL for the API (should NOT include /api/v2) - auth_token: API authentication token + api_key: API key for authentication Returns: List of project dictionaries with 'id' and 'name' fields """ - with get_http_session(auth_token) as session: + with get_http_session(api_key) as session: try: url = f"{base_url.rstrip('/')}/projects/" response = session.get(url, timeout=30) diff --git a/trapdata/antenna/datasets.py b/trapdata/antenna/datasets.py index 7ecc7bd..574707a 100644 --- a/trapdata/antenna/datasets.py +++ b/trapdata/antenna/datasets.py @@ -106,7 +106,7 @@ class RESTDataset(torch.utils.data.IterableDataset): def __init__( self, base_url: str, - auth_token: str, + api_key: str, job_id: int, batch_size: int = 1, image_transforms: torchvision.transforms.Compose | None = None, @@ -116,14 +116,14 @@ def __init__( Args: base_url: Base URL for the API including /api/v2 (e.g., "http://localhost:8000/api/v2") - auth_token: API authentication token + api_key: API key for authentication job_id: The job ID to fetch tasks for batch_size: Number of tasks to request per batch image_transforms: Optional transforms to apply to loaded images """ super().__init__() self.base_url = base_url - self.auth_token = auth_token + self.api_key = api_key self.job_id = job_id self.batch_size = batch_size self.image_transforms = image_transforms or torchvision.transforms.ToTensor() @@ -143,7 +143,7 @@ def _ensure_sessions(self) -> None: issues with num_workers > 0 (SimpleQueue, socket objects, etc.). """ if self._api_session is None: - self._api_session = get_http_session(self.auth_token) + self._api_session = get_http_session(self.api_key) if self._image_fetch_session is None: self._image_fetch_session = get_http_session() if self._executor is None: @@ -420,13 +420,13 @@ def get_rest_dataloader( Args: job_id: Job ID to fetch tasks for settings: Settings object. 
diff --git a/trapdata/antenna/datasets.py b/trapdata/antenna/datasets.py
index 7ecc7bd..574707a 100644
--- a/trapdata/antenna/datasets.py
+++ b/trapdata/antenna/datasets.py
@@ -106,7 +106,7 @@ class RESTDataset(torch.utils.data.IterableDataset):
     def __init__(
         self,
         base_url: str,
-        auth_token: str,
+        api_key: str,
         job_id: int,
         batch_size: int = 1,
         image_transforms: torchvision.transforms.Compose | None = None,
@@ -116,14 +116,14 @@ def __init__(
         Args:
             base_url: Base URL for the API including /api/v2
                 (e.g., "http://localhost:8000/api/v2")
-            auth_token: API authentication token
+            api_key: API key for authentication
             job_id: The job ID to fetch tasks for
             batch_size: Number of tasks to request per batch
             image_transforms: Optional transforms to apply to loaded images
         """
         super().__init__()
         self.base_url = base_url
-        self.auth_token = auth_token
+        self.api_key = api_key
         self.job_id = job_id
         self.batch_size = batch_size
         self.image_transforms = image_transforms or torchvision.transforms.ToTensor()
@@ -143,7 +143,7 @@ def _ensure_sessions(self) -> None:
         issues with num_workers > 0 (SimpleQueue, socket objects, etc.).
         """
         if self._api_session is None:
-            self._api_session = get_http_session(self.auth_token)
+            self._api_session = get_http_session(self.api_key)
         if self._image_fetch_session is None:
             self._image_fetch_session = get_http_session()
         if self._executor is None:
@@ -420,13 +420,13 @@ def get_rest_dataloader(
     Args:
         job_id: Job ID to fetch tasks for
         settings: Settings object. Relevant fields:
-            - antenna_api_base_url / antenna_api_auth_token
+            - antenna_api_base_url / antenna_api_key
             - antenna_api_batch_size (tasks per API call and GPU batch size)
             - num_workers (DataLoader subprocesses)
     """
     dataset = RESTDataset(
         base_url=settings.antenna_api_base_url,
-        auth_token=settings.antenna_api_auth_token,
+        api_key=settings.antenna_api_key,
         job_id=job_id,
         batch_size=settings.antenna_api_batch_size,
     )
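
Reviewer note (not part of the patch): building the dataloader from settings after the rename, assuming the job_id/settings signature documented above; the job id is hypothetical.

    from trapdata.antenna.datasets import get_rest_dataloader
    from trapdata.settings import read_settings

    settings = read_settings()  # antenna_api_key is now populated from AMI_ANTENNA_API_KEY
    dataloader = get_rest_dataloader(job_id=123, settings=settings)  # hypothetical job id
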
diff --git a/trapdata/antenna/registration.py b/trapdata/antenna/registration.py
index 4b41e31..3a27f86 100644
--- a/trapdata/antenna/registration.py
+++ b/trapdata/antenna/registration.py
@@ -1,8 +1,10 @@
 """Pipeline registration with Antenna projects."""
 
+import platform
+import socket
+
 import requests
 
-from trapdata.antenna.client import get_full_service_name
 from trapdata.antenna.schemas import (
     AsyncPipelineRegistrationRequest,
     AsyncPipelineRegistrationResponse,
@@ -13,11 +15,30 @@
 from trapdata.settings import Settings, read_settings
 
 
+def _get_version() -> str:
+    """Return the ami-data-companion package version, or 'unknown'."""
+    try:
+        from importlib.metadata import PackageNotFoundError, version
+
+        return version("trapdata")
+    except PackageNotFoundError:
+        return "unknown"
+
+
+def _build_client_info() -> dict:
+    """Build a client_info dict with hostname, software, version, and platform."""
+    return {
+        "hostname": socket.gethostname(),
+        "software": "ami-data-companion",
+        "version": _get_version(),
+        "platform": platform.platform(),
+    }
+
+
 def register_pipelines_for_project(
     base_url: str,
-    auth_token: str,
+    api_key: str,
     project_id: int,
-    service_name: str,
     pipeline_configs: list,
 ) -> tuple[bool, str]:
     """
@@ -25,18 +46,18 @@
 
     Args:
         base_url: Base URL for the API (should NOT include /api/v2)
-        auth_token: API authentication token
+        api_key: API key for authentication
         project_id: Project ID to register pipelines for
-        service_name: Name of the processing service
         pipeline_configs: Pre-built pipeline configuration objects
 
     Returns:
         Tuple of (success: bool, message: str)
     """
-    with get_http_session(auth_token=auth_token) as session:
+    with get_http_session(api_key=api_key) as session:
         try:
             registration_request = AsyncPipelineRegistrationRequest(
-                processing_service_name=service_name, pipelines=pipeline_configs
+                pipelines=pipeline_configs,
+                client_info=_build_client_info(),
             )
 
             url = f"{base_url.rstrip('/')}/projects/{project_id}/pipelines/"
@@ -70,7 +91,6 @@
 
 def register_pipelines(
     project_ids: list[int],
-    service_name: str,
     settings: Settings | None = None,
 ) -> None:
     """
@@ -78,7 +98,6 @@
     Args:
         project_ids: List of specific project IDs to register for. If empty,
             registers for all accessible projects.
-        service_name: Name of the processing service
         settings: Settings object with antenna_api_* configuration (defaults to read_settings())
     """
     # Import here to avoid circular import
@@ -89,22 +108,12 @@
         settings = read_settings()
 
     base_url = settings.antenna_api_base_url
-    auth_token = settings.antenna_api_auth_token
+    api_key = settings.antenna_api_key
 
-    if not auth_token:
-        logger.error("AMI_ANTENNA_API_AUTH_TOKEN environment variable not set")
+    if not api_key:
+        logger.error("AMI_ANTENNA_API_KEY environment variable not set")
         return
 
-    if not service_name or not service_name.strip():
-        logger.error(
-            "Service name is required for registration. "
-            "Configure AMI_ANTENNA_SERVICE_NAME via environment variable, .env file, or Kivy settings."
-        )
-        return
-
-    # Add hostname to service name
-    full_service_name = get_full_service_name(service_name)
-
     # Get projects to register for
     projects_to_process = []
     if project_ids:
@@ -116,7 +125,7 @@
     else:
         # Fetch all accessible projects
         logger.info("Fetching all accessible projects...")
-        all_projects = get_user_projects(base_url, auth_token)
+        all_projects = get_user_projects(base_url, api_key)
         projects_to_process = all_projects
 
     logger.info(f"Found {len(projects_to_process)} accessible projects")
@@ -146,9 +155,8 @@
 
         success, message = register_pipelines_for_project(
             base_url=base_url,
-            auth_token=auth_token,
+            api_key=api_key,
             project_id=project_id,
-            service_name=full_service_name,
            pipeline_configs=pipeline_configs,
         )
 
@@ -164,7 +172,6 @@
 
     # Summary report
     logger.info("\n=== Registration Summary ===")
-    logger.info(f"Service name: {full_service_name}")
     logger.info(f"Total projects processed: {len(projects_to_process)}")
     logger.info(f"Successful registrations: {len(successful_registrations)}")
     logger.info(f"Failed registrations: {len(failed_registrations)}")
diff --git a/trapdata/antenna/result_posting.py b/trapdata/antenna/result_posting.py
index cd76737..26edf13 100644
--- a/trapdata/antenna/result_posting.py
+++ b/trapdata/antenna/result_posting.py
@@ -14,7 +14,7 @@
 
 Usage:
     poster = ResultPoster(max_pending=5)
-    poster.post_async(base_url, auth_token, job_id, results)
+    poster.post_async(base_url, api_key, job_id, results)
     metrics = poster.get_metrics()
     poster.shutdown()
 """
@@ -60,7 +60,7 @@ class ResultPoster:
 
     Example:
         poster = ResultPoster(max_pending=10)
-        poster.post_async(base_url, auth_token, job_id, results)
+        poster.post_async(base_url, api_key, job_id, results)
         metrics = poster.get_metrics()
         poster.shutdown()
     """
@@ -82,7 +82,7 @@ def __init__(
     def post_async(
         self,
         base_url: str,
-        auth_token: str,
+        api_key: str,
         job_id: int,
         results: list,
     ) -> None:
@@ -93,7 +93,7 @@ def post_async(
 
         Args:
             base_url: Antenna API base URL
-            auth_token: API authentication token
+            api_key: API key for authentication
             job_id: Job ID for the results
             results: List of result objects to post
         """
@@ -134,7 +134,7 @@ def post_async(
             future = self.executor.submit(
                 self._post_with_timing,
                 base_url,
-                auth_token,
+                api_key,
                 job_id,
                 results,
                 start_time,
@@ -149,7 +149,7 @@
     def _post_with_timing(
         self,
         base_url: str,
-        auth_token: str,
+        api_key: str,
         job_id: int,
         results: list,
         start_time: float,
@@ -158,7 +158,7 @@
 
         Args:
             base_url: Antenna API base URL
-            auth_token: API authentication token
+            api_key: API key for authentication
             job_id: Job ID for the results
             results: List of result objects to post
             start_time: Timestamp when the post was initiated
@@ -167,7 +167,7 @@
             True if successful, False otherwise
         """
         try:
-            success = post_batch_results(base_url, auth_token, job_id, results)
+            success = post_batch_results(base_url, api_key, job_id, results)
             elapsed_time = time.time() - start_time
 
             with self._metrics_lock:
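
Reviewer note (not part of the patch): registration no longer takes a service name; identity comes from the API key, and _build_client_info() supplies diagnostics roughly like the commented dict below (all values illustrative).

    from trapdata.antenna.registration import register_pipelines

    register_pipelines(project_ids=[1, 2])  # empty list registers all accessible projects

    # client_info sent with each registration request, e.g.:
    # {
    #     "hostname": "trap-station-01",
    #     "software": "ami-data-companion",
    #     "version": "1.2.3",
    #     "platform": "Linux-6.1.0-x86_64-with-glibc2.36",
    # }
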
diff --git a/trapdata/antenna/schemas.py b/trapdata/antenna/schemas.py
index be64eef..1389891 100644
--- a/trapdata/antenna/schemas.py
+++ b/trapdata/antenna/schemas.py
@@ -85,11 +85,15 @@ class AntennaResultPostResponse(pydantic.BaseModel):
 
 class AsyncPipelineRegistrationRequest(pydantic.BaseModel):
     """
-    Request to register pipelines from an async processing service
+    Request to register pipelines from an async processing service.
+
+    The server identifies the processing service from the API key,
+    so no service name is needed. Optional client_info provides
+    metadata about the client for diagnostics.
     """
 
-    processing_service_name: str
     pipelines: list[PipelineConfigResponse] = []
+    client_info: dict | None = None
 
 
 class AsyncPipelineRegistrationResponse(pydantic.BaseModel):
diff --git a/trapdata/antenna/tests/test_memory_leak.py b/trapdata/antenna/tests/test_memory_leak.py
index a09c14c..4510174 100644
--- a/trapdata/antenna/tests/test_memory_leak.py
+++ b/trapdata/antenna/tests/test_memory_leak.py
@@ -51,7 +51,7 @@ def setUp(self):
     def _make_settings(self):
         settings = MagicMock()
         settings.antenna_api_base_url = "http://testserver/api/v2"
-        settings.antenna_api_auth_token = "test-token"
+        settings.antenna_api_key = "test-api-key"
         settings.antenna_api_batch_size = 2
         settings.num_workers = 0
         settings.localization_batch_size = 2
diff --git a/trapdata/antenna/tests/test_worker.py b/trapdata/antenna/tests/test_worker.py
index f6b9079..2cedbb3 100644
--- a/trapdata/antenna/tests/test_worker.py
+++ b/trapdata/antenna/tests/test_worker.py
@@ -41,7 +41,7 @@ def test_dataloader_starts_with_num_workers(self):
         """Creating an iterator pickles the dataset to send to worker subprocesses."""
         dataset = RESTDataset(
             base_url="http://localhost:1/api/v2",
-            auth_token="test-token",
+            api_key="test-api-key",
             job_id=1,
             batch_size=4,
         )
@@ -167,7 +167,7 @@ def _make_dataset(self, job_id: int = 42, batch_size: int = 2) -> RESTDataset:
             base_url="http://testserver/api/v2",
             job_id=job_id,
             batch_size=batch_size,
-            auth_token="test-token",
+            api_key="test-api-key",
         )
 
     def test_multiple_batches(self):
@@ -223,7 +223,7 @@ def test_returns_job_ids(self):
 
         with patch_antenna_api_requests(self.antenna_client):
             result = get_jobs(
                 "http://testserver/api/v2",
-                "test-token",
+                "test-api-key",
                 ["moths_2024"],
             )
@@ -256,7 +256,7 @@ def _make_settings(self):
         """Create mock settings for worker."""
         settings = MagicMock()
         settings.antenna_api_base_url = "http://testserver/api/v2"
-        settings.antenna_api_auth_token = "test-token"
+        settings.antenna_api_key = "test-api-key"
         settings.antenna_api_batch_size = 2
         settings.num_workers = 0  # Disable multiprocessing for tests
         settings.localization_batch_size = 2  # Real integer for batch processing
@@ -417,7 +417,7 @@ def setUp(self):
     def _make_settings(self):
         settings = MagicMock()
         settings.antenna_api_base_url = "http://testserver/api/v2"
-        settings.antenna_api_auth_token = "test-token"
+        settings.antenna_api_key = "test-api-key"
         settings.antenna_api_batch_size = 2
         settings.num_workers = 0
         settings.localization_batch_size = 2  # Real integer for batch processing
@@ -455,9 +455,8 @@ def test_full_workflow_with_real_inference(self):
         ]
         success, _ = register_pipelines_for_project(
             base_url="http://testserver/api/v2",
-            auth_token="test-token",
+            api_key="test-api-key",
             project_id=1,
-            service_name="Test Worker",
             pipeline_configs=pipeline_configs,
         )
         assert success is True
@@ -465,7 +464,7 @@
         # Step 2: Get jobs
         jobs = get_jobs(
             "http://testserver/api/v2",
-            "test-token",
+            "test-api-key",
             [pipeline_slug],
         )
         job_ids = [job_id for job_id, _ in jobs]
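
Reviewer note (not part of the patch): the updated request model, serialized the way the client code above posts it; the client_info values are illustrative.

    from trapdata.antenna.schemas import AsyncPipelineRegistrationRequest

    request = AsyncPipelineRegistrationRequest(
        pipelines=[],  # PipelineConfigResponse objects in practice
        client_info={"hostname": "trap-station-01", "software": "ami-data-companion"},
    )
    payload = request.model_dump(mode="json")  # no processing_service_name field anymore
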
diff --git a/trapdata/antenna/worker.py b/trapdata/antenna/worker.py
index 2b7e1db..a440797 100644
--- a/trapdata/antenna/worker.py
+++ b/trapdata/antenna/worker.py
@@ -10,7 +10,7 @@
 import torch
 import torch.multiprocessing as mp
 
-from trapdata.antenna.client import get_full_service_name, get_jobs
+from trapdata.antenna.client import get_jobs
 from trapdata.antenna.datasets import CUDAPrefetcher, get_rest_dataloader
 from trapdata.antenna.result_posting import ResultPoster
 from trapdata.antenna.schemas import AntennaTaskResult, AntennaTaskResultError
@@ -38,18 +38,11 @@ def run_worker(pipelines: list[str]):
     """
     settings = read_settings()
 
-    # Validate auth token
-    if not settings.antenna_api_auth_token:
+    # Validate API key
+    if not settings.antenna_api_key:
         raise ValueError(
-            "AMI_ANTENNA_API_AUTH_TOKEN environment variable must be set. "
-            "Get your auth token from your Antenna project settings."
-        )
-
-    # Validate service name
-    if not settings.antenna_service_name or not settings.antenna_service_name.strip():
-        raise ValueError(
-            "AMI_ANTENNA_SERVICE_NAME configuration setting must be set. "
-            "Configure it via environment variable or .env file."
+            "AMI_ANTENNA_API_KEY environment variable must be set. "
+            "Get your API key from your Antenna project settings."
         )
 
     gpu_count = torch.cuda.device_count()
@@ -86,10 +79,6 @@ def _worker_loop(gpu_id: int, pipelines: list[str]):
         f"AMI worker instance {gpu_id} pinned to GPU {gpu_id}: {torch.cuda.get_device_name(gpu_id)}"
     )
 
-    # Build full service name with hostname
-    full_service_name = get_full_service_name(settings.antenna_service_name)
-    logger.info(f"Running worker as: {full_service_name}")
-
     while True:
         # TODO CGJS: Support pulling and prioritizing single image tasks, which are used in interactive testing
         # These should probably come from a dedicated endpoint and should preempt batch jobs under the assumption that they
@@ -100,7 +89,7 @@ def _worker_loop(gpu_id: int, pipelines: list[str]):
         )
         jobs = get_jobs(
             base_url=settings.antenna_api_base_url,
-            auth_token=settings.antenna_api_auth_token,
+            api_key=settings.antenna_api_key,
             pipeline_slugs=pipelines,
         )
         for job_id, pipeline in jobs:
@@ -500,7 +489,7 @@ def _process_job(
         # Post results asynchronously (non-blocking)
         result_poster.post_async(
             settings.antenna_api_base_url,
-            settings.antenna_api_auth_token,
+            settings.antenna_api_key,
             job_id,
             batch_results,
         )
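
Reviewer note (not part of the patch): the worker now fails fast on a missing key and no longer needs a service name; the pipeline slug is hypothetical.

    from trapdata.antenna.worker import run_worker

    # Raises ValueError at startup unless AMI_ANTENNA_API_KEY is set.
    run_worker(pipelines=["moths_2024"])
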
diff --git a/trapdata/api/utils.py b/trapdata/api/utils.py
index 3d8e02a..f5e6459 100644
--- a/trapdata/api/utils.py
+++ b/trapdata/api/utils.py
@@ -38,7 +38,7 @@ def get_crop_fname(source_image: SourceImage, bbox: BoundingBox) -> str:
     return f"{source_name}/{bbox_name}-{timestamp}.jpg"
 
 
-def get_http_session(auth_token: str | None = None) -> requests.Session:
+def get_http_session(api_key: str | None = None) -> requests.Session:
     """
     Create an HTTP session with retry logic for transient failures.
 
@@ -47,10 +47,10 @@ def get_http_session(auth_token: str | None = None) -> requests.Session:
     and network failures, NOT on client errors (4XX). Only GET requests are retried.
 
     TODO: This will likely become part of an AntennaClient class that encapsulates
-    base_url, auth_token, and session management. See docs/claude/planning/antenna-client.md
+    base_url, api_key, and session management. See docs/claude/planning/antenna-client.md
 
     Args:
-        auth_token: Optional API token. If provided, adds "Token {auth_token}" header.
+        api_key: Optional API key. If provided, adds "Api-Key {api_key}" header.
 
     Returns:
         Configured requests.Session with retry adapter mounted
@@ -69,7 +69,7 @@ def get_http_session(auth_token: str | None = None) -> requests.Session:
     session.mount("http://", adapter)
     session.mount("https://", adapter)
 
-    if auth_token:
-        session.headers["Authorization"] = f"Token {auth_token}"
+    if api_key:
+        session.headers["Authorization"] = f"Api-Key {api_key}"
 
     return session
diff --git a/trapdata/cli/worker.py b/trapdata/cli/worker.py
index f1b5782..ed74401 100644
--- a/trapdata/cli/worker.py
+++ b/trapdata/cli/worker.py
@@ -64,18 +64,13 @@ def register(
     This command registers all available pipeline configurations with the Antenna
     platform for the specified projects (or all accessible projects if none specified).
 
-    The service name is read from the AMI_ANTENNA_SERVICE_NAME configuration setting.
-    Hostname will be added automatically to the service name.
+    The processing service is identified by the API key.
 
     Examples:
         ami worker register --project 1 --project 2
         ami worker register  # registers for all accessible projects
     """
     from trapdata.antenna.registration import register_pipelines
-    from trapdata.settings import read_settings
 
-    settings = read_settings()
     project_ids = project if project else []
-    register_pipelines(
-        project_ids=project_ids, service_name=settings.antenna_service_name
-    )
+    register_pipelines(project_ids=project_ids)
diff --git a/trapdata/settings.py b/trapdata/settings.py
index b07e043..03dfb9b 100644
--- a/trapdata/settings.py
+++ b/trapdata/settings.py
@@ -39,8 +39,7 @@ class Settings(BaseSettings):
 
     # Antenna API worker settings
     antenna_api_base_url: str = "http://localhost:8000/api/v2"
-    antenna_api_auth_token: str = ""
-    antenna_service_name: str = "AMI Data Companion"
+    antenna_api_key: str = ""
     antenna_api_batch_size: int = 24
 
     @pydantic.field_validator("image_base_path", "user_data_path")
@@ -158,9 +157,9 @@ class Config:
             "kivy_type": "string",
             "kivy_section": "antenna",
         },
-        "antenna_api_auth_token": {
-            "title": "Antenna API Token",
-            "description": "Authentication token for your Antenna project",
+        "antenna_api_key": {
+            "title": "Antenna API Key",
+            "description": "API key for authenticating with Antenna (format: prefix.secret)",
             "kivy_type": "string",
             "kivy_section": "antenna",
         },
@@ -170,12 +169,6 @@ class Config:
             "kivy_type": "numeric",
             "kivy_section": "antenna",
         },
-        "antenna_service_name": {
-            "title": "Antenna Service Name",
-            "description": "Name for the processing service registration (hostname will be added automatically)",
-            "kivy_type": "string",
-            "kivy_section": "antenna",
-        },
     }
 
     @classmethod
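
Reviewer note (not part of the patch): the header change in get_http_session, shown with a hypothetical key.

    from trapdata.api.utils import get_http_session

    session = get_http_session(api_key="prefix.secret")
    assert session.headers["Authorization"] == "Api-Key prefix.secret"  # was "Token {auth_token}"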