Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions trapdata/antenna/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def create_empty_result(reply_subject: str, image_id: str) -> AntennaTaskResult:
def run_benchmark(
job_id: int,
base_url: str,
auth_token: str,
api_key: str,
num_workers: int,
batch_size: int,
gpu_batch_size: int,
Expand All @@ -60,15 +60,15 @@ def run_benchmark(
Args:
job_id: Job ID to process
base_url: Antenna API base URL
auth_token: API authentication token
api_key: API key for authentication
num_workers: Number of DataLoader workers
batch_size: Batch size for API requests
gpu_batch_size: GPU batch size for DataLoader
"""
# Create settings object
settings = Settings()
settings.antenna_api_base_url = base_url
settings.antenna_api_auth_token = auth_token
settings.antenna_api_key = api_key
settings.antenna_api_batch_size = batch_size
settings.localization_batch_size = gpu_batch_size
settings.num_workers = num_workers
Expand Down Expand Up @@ -134,7 +134,7 @@ def run_benchmark(
# Send acknowledgments asynchronously
result_poster.post_async(
base_url=base_url,
auth_token=auth_token,
api_key=api_key,
job_id=job_id,
results=ack_results,
)
Expand All @@ -156,7 +156,7 @@ def run_benchmark(
if error_results and send_acks:
result_poster.post_async(
base_url=base_url,
auth_token=auth_token,
api_key=api_key,
job_id=job_id,
results=error_results,
)
Expand Down Expand Up @@ -278,16 +278,16 @@ def main() -> int:
args = parser.parse_args()

# Get auth token from environment
auth_token = os.getenv("AMI_ANTENNA_API_AUTH_TOKEN", "")
if not auth_token:
print("ERROR: AMI_ANTENNA_API_AUTH_TOKEN environment variable not set")
api_key = os.getenv("AMI_ANTENNA_API_KEY", "")
if not api_key:
print("ERROR: AMI_ANTENNA_API_KEY environment variable not set")
return 1

# Run the benchmark
run_benchmark(
job_id=args.job_id,
base_url=args.base_url,
auth_token=auth_token,
api_key=api_key,
num_workers=args.num_workers,
batch_size=args.batch_size,
gpu_batch_size=args.gpu_batch_size,
Expand Down
33 changes: 9 additions & 24 deletions trapdata/antenna/client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""Antenna API client for fetching jobs and posting results."""

import socket

import requests

from trapdata.antenna.schemas import (
Expand All @@ -15,22 +13,9 @@
from trapdata.common.logs import logger


def get_full_service_name(service_name: str) -> str:
"""Build full service name with hostname.

Args:
service_name: Base service name

Returns:
Full service name with hostname appended
"""
hostname = socket.gethostname()
return f"{service_name} ({hostname})"


def get_jobs(
base_url: str,
auth_token: str,
api_key: str,
pipeline_slugs: list[str],
) -> list[tuple[int, str]]:
"""Fetch job ids from the API for the given pipelines in a single request.
Expand All @@ -39,13 +24,13 @@ def get_jobs(

Args:
base_url: Antenna API base URL (e.g., "http://localhost:8000/api/v2")
auth_token: API authentication token
api_key: API key for authentication
pipeline_slugs: List of pipeline slugs to filter jobs

Returns:
List of (job_id, pipeline_slug) tuples (possibly empty) on success or error.
"""
with get_http_session(auth_token) as session:
with get_http_session(api_key) as session:
try:
if not pipeline_slugs:
return []
Expand Down Expand Up @@ -73,7 +58,7 @@ def get_jobs(

def post_batch_results(
base_url: str,
auth_token: str,
api_key: str,
job_id: int,
results: list[AntennaTaskResult],
) -> bool:
Expand All @@ -82,7 +67,7 @@ def post_batch_results(

Args:
base_url: Antenna API base URL (e.g., "http://localhost:8000/api/v2")
auth_token: API authentication token
api_key: API key for authentication
job_id: Job ID
results: List of AntennaTaskResult objects

Expand All @@ -92,7 +77,7 @@ def post_batch_results(
url = f"{base_url.rstrip('/')}/jobs/{job_id}/result/"
payload = AntennaTaskResults(results=results)

with get_http_session(auth_token) as session:
with get_http_session(api_key) as session:
try:
response = session.post(
url, json=payload.model_dump(mode="json"), timeout=60
Expand All @@ -108,18 +93,18 @@ def post_batch_results(
return False


def get_user_projects(base_url: str, auth_token: str) -> list[dict]:
def get_user_projects(base_url: str, api_key: str) -> list[dict]:
"""
Fetch all projects the user has access to.

Args:
base_url: Base URL for the API (should NOT include /api/v2)
auth_token: API authentication token
api_key: API key for authentication

Returns:
List of project dictionaries with 'id' and 'name' fields
"""
with get_http_session(auth_token) as session:
with get_http_session(api_key) as session:
try:
url = f"{base_url.rstrip('/')}/projects/"
response = session.get(url, timeout=30)
Expand Down
12 changes: 6 additions & 6 deletions trapdata/antenna/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class RESTDataset(torch.utils.data.IterableDataset):
def __init__(
self,
base_url: str,
auth_token: str,
api_key: str,
job_id: int,
batch_size: int = 1,
image_transforms: torchvision.transforms.Compose | None = None,
Expand All @@ -116,14 +116,14 @@ def __init__(

Args:
base_url: Base URL for the API including /api/v2 (e.g., "http://localhost:8000/api/v2")
auth_token: API authentication token
api_key: API key for authentication
job_id: The job ID to fetch tasks for
batch_size: Number of tasks to request per batch
image_transforms: Optional transforms to apply to loaded images
"""
super().__init__()
self.base_url = base_url
self.auth_token = auth_token
self.api_key = api_key
self.job_id = job_id
self.batch_size = batch_size
self.image_transforms = image_transforms or torchvision.transforms.ToTensor()
Expand All @@ -143,7 +143,7 @@ def _ensure_sessions(self) -> None:
issues with num_workers > 0 (SimpleQueue, socket objects, etc.).
"""
if self._api_session is None:
self._api_session = get_http_session(self.auth_token)
self._api_session = get_http_session(self.api_key)
if self._image_fetch_session is None:
self._image_fetch_session = get_http_session()
if self._executor is None:
Expand Down Expand Up @@ -420,13 +420,13 @@ def get_rest_dataloader(
Args:
job_id: Job ID to fetch tasks for
settings: Settings object. Relevant fields:
- antenna_api_base_url / antenna_api_auth_token
- antenna_api_base_url / antenna_api_key
- antenna_api_batch_size (tasks per API call and GPU batch size)
- num_workers (DataLoader subprocesses)
"""
dataset = RESTDataset(
base_url=settings.antenna_api_base_url,
auth_token=settings.antenna_api_auth_token,
api_key=settings.antenna_api_key,
job_id=job_id,
batch_size=settings.antenna_api_batch_size,
)
Expand Down
59 changes: 33 additions & 26 deletions trapdata/antenna/registration.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Pipeline registration with Antenna projects."""

import platform
import socket

import requests

from trapdata.antenna.client import get_full_service_name
from trapdata.antenna.schemas import (
AsyncPipelineRegistrationRequest,
AsyncPipelineRegistrationResponse,
Expand All @@ -13,30 +15,49 @@
from trapdata.settings import Settings, read_settings


def _get_version() -> str:
"""Return the ami-data-companion package version, or 'unknown'."""
try:
from importlib.metadata import PackageNotFoundError, version

return version("trapdata")
except PackageNotFoundError:
return "unknown"


def _build_client_info() -> dict:
"""Build a client_info dict with hostname, software, version, and platform."""
return {
"hostname": socket.gethostname(),
"software": "ami-data-companion",
"version": _get_version(),
"platform": platform.platform(),
}


def register_pipelines_for_project(
base_url: str,
auth_token: str,
api_key: str,
project_id: int,
service_name: str,
pipeline_configs: list,
) -> tuple[bool, str]:
"""
Register all available pipelines for a specific project.

Args:
base_url: Base URL for the API (should NOT include /api/v2)
auth_token: API authentication token
api_key: API key for authentication
project_id: Project ID to register pipelines for
service_name: Name of the processing service
pipeline_configs: Pre-built pipeline configuration objects

Returns:
Tuple of (success: bool, message: str)
"""
with get_http_session(auth_token=auth_token) as session:
with get_http_session(api_key=api_key) as session:
try:
registration_request = AsyncPipelineRegistrationRequest(
processing_service_name=service_name, pipelines=pipeline_configs
pipelines=pipeline_configs,
client_info=_build_client_info(),
)

url = f"{base_url.rstrip('/')}/projects/{project_id}/pipelines/"
Expand Down Expand Up @@ -70,15 +91,13 @@ def register_pipelines_for_project(

def register_pipelines(
project_ids: list[int],
service_name: str,
settings: Settings | None = None,
) -> None:
"""
Register pipelines for specified projects or all accessible projects.

Args:
project_ids: List of specific project IDs to register for. If empty, registers for all accessible projects.
service_name: Name of the processing service
settings: Settings object with antenna_api_* configuration (defaults to read_settings())
"""
# Import here to avoid circular import
Expand All @@ -89,22 +108,12 @@ def register_pipelines(
settings = read_settings()

base_url = settings.antenna_api_base_url
auth_token = settings.antenna_api_auth_token
api_key = settings.antenna_api_key

if not auth_token:
logger.error("AMI_ANTENNA_API_AUTH_TOKEN environment variable not set")
if not api_key:
logger.error("AMI_ANTENNA_API_KEY environment variable not set")
return

if not service_name or not service_name.strip():
logger.error(
"Service name is required for registration. "
"Configure AMI_ANTENNA_SERVICE_NAME via environment variable, .env file, or Kivy settings."
)
return

# Add hostname to service name
full_service_name = get_full_service_name(service_name)

# Get projects to register for
projects_to_process = []
if project_ids:
Expand All @@ -116,7 +125,7 @@ def register_pipelines(
else:
# Fetch all accessible projects
logger.info("Fetching all accessible projects...")
all_projects = get_user_projects(base_url, auth_token)
all_projects = get_user_projects(base_url, api_key)
projects_to_process = all_projects
logger.info(f"Found {len(projects_to_process)} accessible projects")

Expand Down Expand Up @@ -146,9 +155,8 @@ def register_pipelines(

success, message = register_pipelines_for_project(
base_url=base_url,
auth_token=auth_token,
api_key=api_key,
project_id=project_id,
service_name=full_service_name,
pipeline_configs=pipeline_configs,
)

Expand All @@ -164,7 +172,6 @@ def register_pipelines(

# Summary report
logger.info("\n=== Registration Summary ===")
logger.info(f"Service name: {full_service_name}")
logger.info(f"Total projects processed: {len(projects_to_process)}")
logger.info(f"Successful registrations: {len(successful_registrations)}")
logger.info(f"Failed registrations: {len(failed_registrations)}")
Expand Down
Loading
Loading