diff --git a/.env.example b/.env.example index 0edf1938..2a9178db 100644 --- a/.env.example +++ b/.env.example @@ -8,3 +8,9 @@ AMI_CLASSIFICATION_THRESHOLD=0.6 AMI_LOCALIZATION_BATCH_SIZE=2 AMI_CLASSIFICATION_BATCH_SIZE=20 AMI_NUM_WORKERS=1 + +# Antenna API Worker Settings (for processing jobs from Antenna platform) +# See: https://github.com/RolnickLab/antenna +AMI_ANTENNA_API_BASE_URL=http://localhost:8000/api/v2 +AMI_ANTENNA_API_AUTH_TOKEN=your_antenna_auth_token_here +AMI_ANTENNA_API_BATCH_SIZE=4 diff --git a/.gitignore b/.gitignore index d2b7f1e3..f23bd5c8 100644 --- a/.gitignore +++ b/.gitignore @@ -142,3 +142,6 @@ db_data/ # Test files sample_images bak + +# Local scratch for moving untracked files +scratch/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 20901bdc..fe9c79ca 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -38,7 +38,7 @@ repos: types: [pyi] - repo: https://github.com/pycqa/flake8 - rev: 3.8.3 + rev: 4.0.0 hooks: - id: flake8 files: . diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 00000000..23709de7 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,29 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python Debugger: Current File", + "type": "debugpy", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal" + }, + { + "name": "Run worker", + "type": "debugpy", + "request": "launch", + "module": "trapdata.cli.base", + "args": ["worker"] + }, + { + "name": "Run api", + "type": "debugpy", + "request": "launch", + "module": "trapdata.cli.base", + "args": ["api"] + } + ] +} diff --git a/CLAUDE.md b/CLAUDE.md index 13776c45..89534b9d 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -13,8 +13,9 @@ This file helps AI agents (like Claude) work efficiently with the AMI Data Compa 3. **Always prefer command line tools** to avoid expensive API requests (e.g., use git and jq instead of reading whole files) 4. **Use bulk operations and prefetch patterns** to minimize database queries 5. **Commit often** - Small, focused commits make debugging easier -6. **Use TDD whenever possible** - Tests prevent regressions and document expected behavior -7. **Keep it simple** - Always think hard and evaluate more complex approaches and alternative approaches before moving forward +6. **Use `git add -p` for staging** - Interactive staging to add only relevant changes, creating logical commits +7. **Use TDD whenever possible** - Tests prevent regressions and document expected behavior +8. **Keep it simple** - Always think hard and evaluate more complex approaches and alternative approaches before moving forward ### Think Holistically diff --git a/README.md b/README.md index 30bab6e0..778bfaef 100644 --- a/README.md +++ b/README.md @@ -234,6 +234,60 @@ ami api View the interactive API docs at http://localhost:2000/ +## Running the Antenna Worker + +The worker polls the Antenna platform API for queued image processing jobs, downloads images, runs detection and classification, and posts results back to Antenna. + +**Setup:** + +1. Get your Antenna auth token from your Antenna project settings +2. 
Configure the worker in `.env`: + +```sh +AMI_ANTENNA_API_BASE_URL=https://antenna.insectai.org/api/v2 # Or your Antenna instance +AMI_ANTENNA_API_AUTH_TOKEN=your_token_here +AMI_ANTENNA_API_BATCH_SIZE=4 +AMI_NUM_WORKERS=2 # Safe for REST API (atomic task dequeue) +``` + +**Register pipelines (optional):** + +Register available ML pipelines with your Antenna projects: + +```sh +ami worker register "My Worker Name" --project 1 --project 2 +# Or register for all accessible projects: +ami worker register "My Worker Name" +``` + +**Run the worker:** + +```sh +# Process all pipelines: +ami worker + +# Or specify specific pipeline(s): +ami worker --pipeline moth_binary +ami worker --pipeline moth_binary --pipeline panama_moths_2024 +``` + +The worker will: + +1. Poll Antenna for jobs matching the specified pipeline(s) +2. Download images from the job queue +3. Run detection and classification +4. Post results back to Antenna +5. Repeat until queue is empty, then sleep and poll again + +**Notes:** + +- Multiple workers can run in parallel (they won't duplicate work) +- Auth token ties results to your Antenna project +- Worker continues running until interrupted (Ctrl+C) +- Safe to run multiple workers on different machines + +For more information, see the [Antenna platform documentation](https://github.com/RolnickLab/antenna). + ## Web UI demo (Gradio) A simple web UI is also available to test the inference pipeline. This is a quick way to test models on a remote server via a web browser. diff --git a/pyproject.toml b/pyproject.toml index fcb6963e..d0938613 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ gradio = "^4.41.0" [tool.pytest.ini_options] asyncio_mode = 'auto' +testpaths = ["trapdata/tests", "trapdata/**/tests"] [tool.isort] profile = "black" diff --git a/trapdata/antenna/__init__.py b/trapdata/antenna/__init__.py new file mode 100644 index 00000000..116fea16 --- /dev/null +++ b/trapdata/antenna/__init__.py @@ -0,0 +1,20 @@ +"""Antenna platform integration module. + +This module provides integration with the Antenna platform for remote image processing. +It includes: +- API client for fetching jobs and posting results +- Worker loop for continuous job processing +- Pipeline registration with Antenna projects +- Schemas for Antenna API requests/responses +- Dataset classes for streaming tasks from the API +""" + +from trapdata.antenna import client, datasets, registration, schemas, worker + +__all__ = [ + "client", + "datasets", + "registration", + "schemas", + "worker", +] diff --git a/trapdata/antenna/client.py b/trapdata/antenna/client.py new file mode 100644 index 00000000..3e500310 --- /dev/null +++ b/trapdata/antenna/client.py @@ -0,0 +1,110 @@ +"""Antenna API client for fetching jobs and posting results.""" + +import requests + +from trapdata.antenna.schemas import AntennaJobsListResponse, AntennaTaskResult +from trapdata.api.utils import get_http_session +from trapdata.common.logs import logger + + +def get_jobs( + base_url: str, + auth_token: str, + pipeline_slug: str, +) -> list[int]: + """Fetch job ids from the API for the given pipeline. + + Calls: GET {base_url}/jobs?pipeline__slug=&ids_only=1 + + Args: + base_url: Antenna API base URL (e.g., "http://localhost:8000/api/v2") + auth_token: API authentication token + pipeline_slug: Pipeline slug to filter jobs + + Returns: + List of job ids (possibly empty) on success or error. 
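+
+    Example (illustrative values only; the token and job ids are placeholders):
+        get_jobs("http://localhost:8000/api/v2", token, "moth_binary")  # -> [12, 34]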
+ """ + with get_http_session(auth_token) as session: + try: + url = f"{base_url.rstrip('/')}/jobs" + params = { + "pipeline__slug": pipeline_slug, + "ids_only": 1, + "incomplete_only": 1, + } + + resp = session.get(url, params=params, timeout=30) + resp.raise_for_status() + + # Parse and validate response with Pydantic + jobs_response = AntennaJobsListResponse.model_validate(resp.json()) + return [job.id for job in jobs_response.results] + except requests.RequestException as e: + logger.error(f"Failed to fetch jobs from {base_url}: {e}") + return [] + except Exception as e: + logger.error(f"Failed to parse jobs response: {e}") + return [] + + +def post_batch_results( + base_url: str, + auth_token: str, + job_id: int, + results: list[AntennaTaskResult], +) -> bool: + """ + Post batch results back to the API. + + Args: + base_url: Antenna API base URL (e.g., "http://localhost:8000/api/v2") + auth_token: API authentication token + job_id: Job ID + results: List of AntennaTaskResult objects + + Returns: + True if successful, False otherwise + """ + url = f"{base_url.rstrip('/')}/jobs/{job_id}/result/" + payload = [r.model_dump(mode="json") for r in results] + + with get_http_session(auth_token) as session: + try: + response = session.post(url, json=payload, timeout=60) + response.raise_for_status() + logger.info(f"Successfully posted {len(results)} results to {url}") + return True + except requests.RequestException as e: + logger.error(f"Failed to post results to {url}: {e}") + return False + + +def get_user_projects(base_url: str, auth_token: str) -> list[dict]: + """ + Fetch all projects the user has access to. + + Args: + base_url: Base URL for the API (should NOT include /api/v2) + auth_token: API authentication token + + Returns: + List of project dictionaries with 'id' and 'name' fields + """ + with get_http_session(auth_token) as session: + try: + url = f"{base_url.rstrip('/')}/projects/" + response = session.get(url, timeout=30) + response.raise_for_status() + data = response.json() + + projects = data.get("results", []) + if isinstance(projects, list): + return projects + else: + logger.warning( + f"Unexpected projects format from {url}: {type(projects)}" + ) + return [] + except requests.RequestException as e: + logger.error(f"Failed to fetch projects from {base_url}: {e}") + return [] diff --git a/trapdata/antenna/datasets.py b/trapdata/antenna/datasets.py new file mode 100644 index 00000000..faf56b8f --- /dev/null +++ b/trapdata/antenna/datasets.py @@ -0,0 +1,280 @@ +"""Dataset classes for streaming tasks from the Antenna API.""" + +import typing +from io import BytesIO + +import requests +import torch +import torch.utils.data +import torchvision +from PIL import Image + +from trapdata.antenna.schemas import ( + AntennaPipelineProcessingTask, + AntennaTasksListResponse, +) +from trapdata.api.utils import get_http_session +from trapdata.common.logs import logger + +if typing.TYPE_CHECKING: + from trapdata.settings import Settings + + +class RESTDataset(torch.utils.data.IterableDataset): + """ + An IterableDataset that fetches tasks from a REST API endpoint and loads images. + + The dataset continuously polls the API for tasks, loads the associated images, + and yields them as PyTorch tensors along with metadata. + + IMPORTANT: This dataset assumes the API endpoint atomically removes tasks from + the queue when fetched (like RabbitMQ, SQS, Redis LPOP). This means multiple + DataLoader workers are SAFE and won't process duplicate tasks. 
Each worker + independently fetches different tasks from the shared queue. + + With num_workers > 0: + Worker 1: GET /tasks → receives [1,2,3,4], removed from queue + Worker 2: GET /tasks → receives [5,6,7,8], removed from queue + No duplicates, safe for parallel processing + """ + + def __init__( + self, + base_url: str, + auth_token: str, + job_id: int, + batch_size: int = 1, + image_transforms: torchvision.transforms.Compose | None = None, + ): + """ + Initialize the REST dataset. + + Args: + base_url: Base URL for the API including /api/v2 (e.g., "http://localhost:8000/api/v2") + auth_token: API authentication token + job_id: The job ID to fetch tasks for + batch_size: Number of tasks to request per batch + image_transforms: Optional transforms to apply to loaded images + """ + super().__init__() + self.base_url = base_url + self.job_id = job_id + self.batch_size = batch_size + self.image_transforms = image_transforms or torchvision.transforms.ToTensor() + + # Create persistent sessions for connection pooling + self.api_session = get_http_session(auth_token) + self.image_fetch_session = get_http_session() # No auth for external image URLs + + def __del__(self): + """Clean up HTTP sessions on dataset destruction.""" + if hasattr(self, "api_session"): + self.api_session.close() + if hasattr(self, "image_fetch_session"): + self.image_fetch_session.close() + + def _fetch_tasks(self) -> list[AntennaPipelineProcessingTask]: + """ + Fetch a batch of tasks from the REST API. + + Returns: + List of tasks (possibly empty if queue is drained) + + Raises: + requests.RequestException: If the request fails (network error, etc.) + """ + url = f"{self.base_url.rstrip('/')}/jobs/{self.job_id}/tasks" + params = {"batch": self.batch_size} + + response = self.api_session.get(url, params=params, timeout=30) + response.raise_for_status() + + # Parse and validate response with Pydantic + tasks_response = AntennaTasksListResponse.model_validate(response.json()) + return tasks_response.tasks # Empty list is valid (queue drained) + + def _load_image(self, image_url: str) -> torch.Tensor | None: + """ + Load an image from a URL and convert it to a PyTorch tensor. + + Args: + image_url: URL of the image to load + + Returns: + Image as a PyTorch tensor, or None if loading failed + """ + try: + # Use dedicated session without auth for external images + response = self.image_fetch_session.get(image_url, timeout=30) + response.raise_for_status() + image = Image.open(BytesIO(response.content)) + + # Convert to RGB if necessary + if image.mode != "RGB": + image = image.convert("RGB") + + # Apply transforms + image_tensor = self.image_transforms(image) + return image_tensor + except Exception as e: + logger.error(f"Failed to load image from {image_url}: {e}") + return None + + def __iter__(self): + """ + Iterate over tasks from the REST API. 
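+
+        Iteration ends when the task queue returns an empty batch (job complete)
+        or when a task fetch request fails.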
+ + Yields: + Dictionary containing: + - image: PyTorch tensor of the loaded image + - reply_subject: Reply subject for the task + - batch_index: Index of the image in the batch + - job_id: Job ID + - image_id: Image ID + """ + worker_id = 0 # Initialize before try block to avoid UnboundLocalError + try: + # Get worker info for debugging + worker_info = torch.utils.data.get_worker_info() + worker_id = worker_info.id if worker_info else 0 + num_workers = worker_info.num_workers if worker_info else 1 + + logger.info( + f"Worker {worker_id}/{num_workers} starting iteration for job {self.job_id}" + ) + + while True: + try: + tasks = self._fetch_tasks() + except requests.RequestException as e: + # Fetch failed after retries - log and stop + logger.error( + f"Worker {worker_id}: Fetch failed after retries ({e}), stopping" + ) + break + + if not tasks: + # Queue is empty - job complete + logger.info( + f"Worker {worker_id}: No more tasks for job {self.job_id}" + ) + break + + for task in tasks: + errors = [] + # Load the image + # _, t = log_time() + image_tensor = ( + self._load_image(task.image_url) if task.image_url else None + ) + # _, t = t(f"Loaded image from {image_url}") + + if image_tensor is None: + errors.append("failed to load image") + + if errors: + logger.warning( + f"Worker {worker_id}: Errors in task for image '{task.image_id}': {', '.join(errors)}" + ) + + # Yield the data row + row = { + "image": image_tensor, + "reply_subject": task.reply_subject, + "image_id": task.image_id, + "image_url": task.image_url, + } + if errors: + row["error"] = "; ".join(errors) if errors else None + yield row + + logger.info(f"Worker {worker_id}: Iterator finished") + except Exception as e: + logger.error(f"Worker {worker_id}: Exception in iterator: {e}") + raise + + +def rest_collate_fn(batch: list[dict]) -> dict: + """ + Custom collate function that separates failed and successful items. + + Returns a dict with: + - images: Stacked tensor of valid images (only present if there are successful items) + - reply_subjects: List of reply subjects for valid images + - image_ids: List of image IDs for valid images + - image_urls: List of image URLs for valid images + - failed_items: List of dicts with metadata for failed items + + When all items in the batch have failed, the returned dict will only contain: + - reply_subjects: empty list + - image_ids: empty list + - failed_items: list of failure metadata + """ + successful = [] + failed = [] + + for item in batch: + if item["image"] is None or item.get("error"): + # Failed item + failed.append( + { + "reply_subject": item["reply_subject"], + "image_id": item["image_id"], + "image_url": item.get("image_url"), + "error": item.get("error", "Unknown error"), + } + ) + else: + # Successful item + successful.append(item) + + # Collate successful items + if successful: + result = { + "images": torch.stack([item["image"] for item in successful]), + "reply_subjects": [item["reply_subject"] for item in successful], + "image_ids": [item["image_id"] for item in successful], + "image_urls": [item.get("image_url") for item in successful], + } + else: + # Empty batch - all failed + result = { + "reply_subjects": [], + "image_ids": [], + } + + result["failed_items"] = failed + + return result + + +def get_rest_dataloader( + job_id: int, + settings: "Settings", +) -> torch.utils.data.DataLoader: + """ + Create a DataLoader that fetches tasks from Antenna API. 
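+
+    Usage sketch (mirrors how the worker loop consumes it):
+        loader = get_rest_dataloader(job_id=job_id, settings=settings)
+        for batch in loader:
+            ...  # each batch is a dict built by rest_collate_fn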
+ + Note: num_workers > 0 is SAFE here (unlike local file reading) because: + - Antenna API provides atomic task dequeue (work queue pattern) + - No shared file handles between workers + - Each worker gets different tasks automatically + - Parallel downloads improve throughput for I/O-bound work + + Args: + job_id: Job ID to fetch tasks for + settings: Settings object with antenna_api_* configuration + """ + dataset = RESTDataset( + base_url=settings.antenna_api_base_url, + auth_token=settings.antenna_api_auth_token, + job_id=job_id, + batch_size=settings.antenna_api_batch_size, + ) + + return torch.utils.data.DataLoader( + dataset, + batch_size=settings.localization_batch_size, + num_workers=settings.num_workers, + collate_fn=rest_collate_fn, + ) diff --git a/trapdata/antenna/registration.py b/trapdata/antenna/registration.py new file mode 100644 index 00000000..a78a513f --- /dev/null +++ b/trapdata/antenna/registration.py @@ -0,0 +1,179 @@ +"""Pipeline registration with Antenna projects.""" + +import socket + +import requests + +from trapdata.antenna.schemas import ( + AsyncPipelineRegistrationRequest, + AsyncPipelineRegistrationResponse, +) +from trapdata.api.api import CLASSIFIER_CHOICES, initialize_service_info +from trapdata.api.utils import get_http_session +from trapdata.common.logs import logger +from trapdata.settings import Settings, read_settings + + +def register_pipelines_for_project( + base_url: str, + auth_token: str, + project_id: int, + service_name: str, + pipeline_configs: list, +) -> tuple[bool, str]: + """ + Register all available pipelines for a specific project. + + Args: + base_url: Base URL for the API (should NOT include /api/v2) + auth_token: API authentication token + project_id: Project ID to register pipelines for + service_name: Name of the processing service + pipeline_configs: Pre-built pipeline configuration objects + + Returns: + Tuple of (success: bool, message: str) + """ + with get_http_session(auth_token=auth_token) as session: + try: + registration_request = AsyncPipelineRegistrationRequest( + processing_service_name=service_name, pipelines=pipeline_configs + ) + + url = f"{base_url.rstrip('/')}/projects/{project_id}/pipelines/" + response = session.post( + url, + json=registration_request.model_dump(mode="json"), + timeout=60, + ) + response.raise_for_status() + + result = AsyncPipelineRegistrationResponse.model_validate(response.json()) + return True, f"Created {len(result.pipelines_created)} new pipelines" + + except requests.RequestException as e: + if ( + hasattr(e, "response") + and e.response is not None + and e.response.status_code == 400 + ): + try: + error_data = e.response.json() + error_detail = error_data.get("detail", str(e)) + except Exception: + error_detail = str(e) + return False, f"Registration failed: {error_detail}" + else: + return False, f"Network error during registration: {e}" + except Exception as e: + return False, f"Unexpected error during registration: {e}" + + +def register_pipelines( + project_ids: list[int], + service_name: str, + settings: Settings | None = None, +) -> None: + """ + Register pipelines for specified projects or all accessible projects. + + Args: + project_ids: List of specific project IDs to register for. If empty, registers for all accessible projects. 
+ service_name: Name of the processing service + settings: Settings object with antenna_api_* configuration (defaults to read_settings()) + """ + # Import here to avoid circular import + from trapdata.antenna.client import get_user_projects + + # Get settings from parameter or read from environment + if settings is None: + settings = read_settings() + + base_url = settings.antenna_api_base_url + auth_token = settings.antenna_api_auth_token + + if not auth_token: + logger.error("AMI_ANTENNA_API_AUTH_TOKEN environment variable not set") + return + + if service_name is None: + logger.error("Service name is required for registration") + return + + # Add hostname to service name + hostname = socket.gethostname() + full_service_name = f"{service_name} ({hostname})" + + # Get projects to register for + projects_to_process = [] + if project_ids: + # Use specified project IDs + projects_to_process = [ + {"id": pid, "name": f"Project {pid}"} for pid in project_ids + ] + logger.info(f"Registering pipelines for specified projects: {project_ids}") + else: + # Fetch all accessible projects + logger.info("Fetching all accessible projects...") + all_projects = get_user_projects(base_url, auth_token) + projects_to_process = all_projects + logger.info(f"Found {len(projects_to_process)} accessible projects") + + if not projects_to_process: + logger.warning("No projects found to register pipelines for") + return + + # Initialize service info once to get pipeline configurations + logger.info("Initializing pipeline configurations...") + service_info = initialize_service_info() + pipeline_configs = service_info.pipelines + logger.info(f"Generated {len(pipeline_configs)} pipeline configurations") + + # Register pipelines for each project + successful_registrations = [] + failed_registrations = [] + + logger.info(f"Available pipelines to register: {list(CLASSIFIER_CHOICES.keys())}") + + for project in projects_to_process: + project_id = project["id"] + project_name = project.get("name", f"Project {project_id}") + + logger.info( + f"Registering pipelines for project {project_id} ({project_name})..." 
+ ) + + success, message = register_pipelines_for_project( + base_url=base_url, + auth_token=auth_token, + project_id=project_id, + service_name=full_service_name, + pipeline_configs=pipeline_configs, + ) + + if success: + successful_registrations.append((project_id, project_name, message)) + logger.info(f"✓ Project {project_id} ({project_name}): {message}") + else: + failed_registrations.append((project_id, project_name, message)) + if "Processing service already exists" in message: + logger.warning(f"⚠ Project {project_id} ({project_name}): {message}") + else: + logger.error(f"✗ Project {project_id} ({project_name}): {message}") + + # Summary report + logger.info("\n=== Registration Summary ===") + logger.info(f"Service name: {full_service_name}") + logger.info(f"Total projects processed: {len(projects_to_process)}") + logger.info(f"Successful registrations: {len(successful_registrations)}") + logger.info(f"Failed registrations: {len(failed_registrations)}") + + if successful_registrations: + logger.info("\nSuccessful registrations:") + for project_id, project_name, message in successful_registrations: + logger.info(f" - Project {project_id} ({project_name}): {message}") + + if failed_registrations: + logger.info("\nFailed registrations:") + for project_id, project_name, message in failed_registrations: + logger.info(f" - Project {project_id} ({project_name}): {message}") diff --git a/trapdata/antenna/schemas.py b/trapdata/antenna/schemas.py new file mode 100644 index 00000000..fa83ad73 --- /dev/null +++ b/trapdata/antenna/schemas.py @@ -0,0 +1,87 @@ +"""Pydantic schemas for Antenna API requests and responses.""" + +import pydantic + +from trapdata.api.schemas import PipelineConfigResponse, PipelineResultsResponse + +# @TODO move more schemas here that are Antenna-specific from api/schemas.py + + +class AntennaPipelineProcessingTask(pydantic.BaseModel): + """ + A task representing a single image or detection to be processed in an async pipeline. + """ + + id: str + image_id: str + image_url: str + reply_subject: str | None = None # The NATS subject to send the result to + # TODO: Do we need these? 
+ # detections: list[DetectionRequest] | None = None + # config: PipelineRequestConfigParameters | dict | None = None + + +class AntennaJobListItem(pydantic.BaseModel): + """A single job item from the Antenna jobs list API response.""" + + id: int + + +class AntennaJobsListResponse(pydantic.BaseModel): + """Response from Antenna API GET /api/v2/jobs with ids_only=1.""" + + results: list[AntennaJobListItem] + + +class AntennaTasksListResponse(pydantic.BaseModel): + """Response from Antenna API GET /api/v2/jobs/{job_id}/tasks.""" + + tasks: list[AntennaPipelineProcessingTask] + + +class AntennaTaskResultError(pydantic.BaseModel): + """Error result for a single Antenna task that failed to process.""" + + error: str + image_id: str | None = None + + +class AntennaTaskResult(pydantic.BaseModel): + """Result for a single Antenna task, either success or error.""" + + reply_subject: str | None = None + result: PipelineResultsResponse | AntennaTaskResultError + + +class AntennaTaskResults(pydantic.BaseModel): + """Batch of task results to post back to Antenna API.""" + + results: list[AntennaTaskResult] = pydantic.Field(default_factory=list) + + +class AsyncPipelineRegistrationRequest(pydantic.BaseModel): + """ + Request to register pipelines from an async processing service + """ + + processing_service_name: str + pipelines: list[PipelineConfigResponse] = [] + + +class AsyncPipelineRegistrationResponse(pydantic.BaseModel): + """ + Response from registering pipelines with a project. + """ + + pipelines_created: list[str] = pydantic.Field( + default_factory=list, + description="List of pipeline slugs that were created", + ) + pipelines_updated: list[str] = pydantic.Field( + default_factory=list, + description="List of pipeline slugs that were updated", + ) + processing_service_id: int | None = pydantic.Field( + default=None, + description="ID of the processing service that was created or updated", + ) diff --git a/trapdata/antenna/service.conf b/trapdata/antenna/service.conf new file mode 100644 index 00000000..0ef55b6f --- /dev/null +++ b/trapdata/antenna/service.conf @@ -0,0 +1,16 @@ +# Example supervisord configuration for AMI Antenna worker +# to run it as a continuous background service +[program:ami-antenna-worker] +directory=/home/debian/ami-data-companion +command=/home/debian/miniconda3/bin/ami worker run +autostart=true +autorestart=true +# stopsignal=KILL +stopasgroup=true +killasgroup=true +stderr_logfile=/var/log/ami.err.log +stdout_logfile=/var/log/ami.out.log +# process_name=%(program_name)s_%(process_num)02d +environment=HOME="/home/debian",USER="debian" +user=debian + diff --git a/trapdata/antenna/tests/__init__.py b/trapdata/antenna/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/trapdata/antenna/tests/antenna_api_server.py b/trapdata/antenna/tests/antenna_api_server.py new file mode 100644 index 00000000..18eafcd8 --- /dev/null +++ b/trapdata/antenna/tests/antenna_api_server.py @@ -0,0 +1,188 @@ +"""Mock Antenna API server for integration testing. + +This module provides a FastAPI application that mocks the Antenna API endpoints +used by the worker. It allows tests to validate the API contract without +requiring an actual Antenna server. 
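+
+Typical test flow: enqueue tasks with setup_job(), route HTTP calls through a
+TestClient(app) instance, inspect get_posted_results(), and call reset() between tests.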
+""" + +from fastapi import FastAPI, HTTPException + +from trapdata.antenna.schemas import ( + AntennaJobListItem, + AntennaJobsListResponse, + AntennaPipelineProcessingTask, + AntennaTaskResult, + AntennaTasksListResponse, + AsyncPipelineRegistrationRequest, + AsyncPipelineRegistrationResponse, +) + +app = FastAPI() + +# State management for tests +_jobs_queue: dict[int, list[AntennaPipelineProcessingTask]] = {} +_posted_results: dict[int, list[AntennaTaskResult]] = {} +_projects: list[dict] = [] +_registered_pipelines: dict[int, list[str]] = {} # project_id -> pipeline slugs + + +@app.get("/api/v2/jobs") +def get_jobs(pipeline__slug: str, ids_only: int, incomplete_only: int): + """Return available job IDs. + + Args: + pipeline__slug: Pipeline slug filter + ids_only: If 1, return only job IDs + incomplete_only: If 1, return only incomplete jobs + + Returns: + AntennaJobsListResponse with list of job IDs + """ + # Return all jobs in queue (for testing, we return all registered jobs) + job_ids = list(_jobs_queue.keys()) + results = [AntennaJobListItem(id=job_id) for job_id in job_ids] + return AntennaJobsListResponse(results=results) + + +@app.get("/api/v2/jobs/{job_id}/tasks") +def get_tasks(job_id: int, batch: int): + """Return batch of tasks (atomically remove from queue). + + Args: + job_id: Job ID to fetch tasks for + batch: Number of tasks to return + + Returns: + AntennaTasksListResponse with batch of tasks + """ + if job_id not in _jobs_queue: + return AntennaTasksListResponse(tasks=[]) + + # Get up to `batch` tasks and remove them from queue + tasks = _jobs_queue[job_id][:batch] + _jobs_queue[job_id] = _jobs_queue[job_id][batch:] + + return AntennaTasksListResponse(tasks=tasks) + + +@app.post("/api/v2/jobs/{job_id}/result/") +def post_results(job_id: int, payload: list[dict]): + """Store posted results for test validation. + + Args: + job_id: Job ID to post results for + payload: List of AntennaTaskResult dicts + + Returns: + Success status + """ + if job_id not in _posted_results: + _posted_results[job_id] = [] + + # Parse each result dict into AntennaTaskResult + for result_dict in payload: + task_result = AntennaTaskResult(**result_dict) + _posted_results[job_id].append(task_result) + + return {"status": "ok"} + + +@app.get("/api/v2/projects/") +def get_projects(): + """Return list of projects the user has access to. + + Returns: + Paginated response with list of projects + """ + return {"results": _projects} + + +@app.post("/api/v2/projects/{project_id}/pipelines/") +def register_pipelines(project_id: int, payload: dict): + """Register pipelines for a project. 
+ + Args: + project_id: Project ID to register pipelines for + payload: AsyncPipelineRegistrationRequest as dict + + Returns: + AsyncPipelineRegistrationResponse + """ + # Validate request + request = AsyncPipelineRegistrationRequest(**payload) + + # Check if project exists + project_ids = [p["id"] for p in _projects] + if project_id not in project_ids: + raise HTTPException(status_code=404, detail="Project not found") + + # Track registered pipelines + if project_id not in _registered_pipelines: + _registered_pipelines[project_id] = [] + + created = [] + for pipeline in request.pipelines: + if pipeline.slug not in _registered_pipelines[project_id]: + _registered_pipelines[project_id].append(pipeline.slug) + created.append(pipeline.slug) + + return AsyncPipelineRegistrationResponse( + pipelines_created=created, + pipelines_updated=[], + processing_service_id=1, + ) + + +# Test helper methods + + +def setup_job(job_id: int, tasks: list[AntennaPipelineProcessingTask]): + """Populate job queue for testing. + + Args: + job_id: Job ID to setup + tasks: List of tasks to add to the queue + """ + _jobs_queue[job_id] = tasks.copy() + + +def get_posted_results(job_id: int) -> list[AntennaTaskResult]: + """Retrieve results posted by worker. + + Args: + job_id: Job ID to get results for + + Returns: + List of posted task results + """ + return _posted_results.get(job_id, []) + + +def setup_projects(projects: list[dict]): + """Setup projects for testing. + + Args: + projects: List of project dicts with 'id' and 'name' fields + """ + _projects.clear() + _projects.extend(projects) + + +def get_registered_pipelines(project_id: int) -> list[str]: + """Get list of pipeline slugs registered for a project. + + Args: + project_id: Project ID to get pipelines for + + Returns: + List of pipeline slugs + """ + return _registered_pipelines.get(project_id, []) + + +def reset(): + """Clear all state between tests.""" + _jobs_queue.clear() + _posted_results.clear() + _projects.clear() + _registered_pipelines.clear() diff --git a/trapdata/antenna/tests/test_worker.py b/trapdata/antenna/tests/test_worker.py new file mode 100644 index 00000000..4a83958a --- /dev/null +++ b/trapdata/antenna/tests/test_worker.py @@ -0,0 +1,488 @@ +"""Integration tests for the REST worker and related utilities. + +These tests validate the Antenna API contract and run real ML inference through +the worker's unique code path (RESTDataset → rest_collate_fn → batch processing). +Only external service dependencies are mocked - ML models and image loading are real. 
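+
+These tests can be run directly with `pytest trapdata/antenna/tests/test_worker.py`.
+Because real models are used, the first run may need to download model weights.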
+""" + +import pathlib +from unittest import TestCase +from unittest.mock import MagicMock + +import torch +from fastapi.testclient import TestClient + +from trapdata.antenna.client import get_jobs +from trapdata.antenna.datasets import RESTDataset, rest_collate_fn +from trapdata.antenna.registration import register_pipelines_for_project +from trapdata.antenna.schemas import ( + AntennaPipelineProcessingTask, + AntennaTaskResult, + AntennaTaskResultError, + PipelineConfigResponse, +) +from trapdata.antenna.tests import antenna_api_server +from trapdata.antenna.tests.antenna_api_server import app as antenna_app +from trapdata.antenna.worker import _process_job +from trapdata.api.schemas import PipelineResultsResponse +from trapdata.api.tests.image_server import StaticFileTestServer +from trapdata.api.tests.utils import get_test_image_urls, patch_antenna_api_requests +from trapdata.tests import TEST_IMAGES_BASE_PATH + +# --------------------------------------------------------------------------- +# TestRestCollateFn - Unit tests for collation logic +# --------------------------------------------------------------------------- + + +class TestRestCollateFn(TestCase): + """Tests for rest_collate_fn which separates successful/failed items.""" + + def test_all_successful(self): + batch = [ + { + "image": torch.rand(3, 64, 64), + "reply_subject": "subj1", + "image_id": "img1", + "image_url": "http://example.com/1.jpg", + }, + { + "image": torch.rand(3, 64, 64), + "reply_subject": "subj2", + "image_id": "img2", + "image_url": "http://example.com/2.jpg", + }, + ] + result = rest_collate_fn(batch) + + assert "images" in result + assert result["images"].shape == (2, 3, 64, 64) + assert result["image_ids"] == ["img1", "img2"] + assert result["reply_subjects"] == ["subj1", "subj2"] + assert result["failed_items"] == [] + + def test_all_failed(self): + batch = [ + { + "image": None, + "reply_subject": "subj1", + "image_id": "img1", + "image_url": "http://example.com/1.jpg", + "error": "download failed", + }, + { + "image": None, + "reply_subject": "subj2", + "image_id": "img2", + "image_url": "http://example.com/2.jpg", + "error": "timeout", + }, + ] + result = rest_collate_fn(batch) + + assert "images" not in result + assert result["image_ids"] == [] + assert result["reply_subjects"] == [] + assert len(result["failed_items"]) == 2 + assert result["failed_items"][0]["image_id"] == "img1" + assert result["failed_items"][1]["error"] == "timeout" + + def test_mixed(self): + batch = [ + { + "image": torch.rand(3, 64, 64), + "reply_subject": "subj1", + "image_id": "img1", + "image_url": "http://example.com/1.jpg", + }, + { + "image": None, + "reply_subject": "subj2", + "image_id": "img2", + "image_url": "http://example.com/2.jpg", + "error": "404", + }, + ] + result = rest_collate_fn(batch) + + assert result["images"].shape == (1, 3, 64, 64) + assert result["image_ids"] == ["img1"] + assert len(result["failed_items"]) == 1 + assert result["failed_items"][0]["image_id"] == "img2" + + +# --------------------------------------------------------------------------- +# TestRESTDatasetIntegration - Integration tests with real image loading +# --------------------------------------------------------------------------- + + +class TestRESTDatasetIntegration(TestCase): + """Integration tests for RESTDataset that fetch tasks and load real images.""" + + @classmethod + def setUpClass(cls): + # Setup file server for test images + cls.test_images_dir = pathlib.Path(TEST_IMAGES_BASE_PATH) + cls.file_server = 
StaticFileTestServer(cls.test_images_dir) + cls.file_server.start() # Start server and keep it running for all tests + + # Setup mock Antenna API + cls.antenna_client = TestClient(antenna_app) + + @classmethod + def tearDownClass(cls): + cls.file_server.stop() + + def setUp(self): + # Reset state between tests + antenna_api_server.reset() + + def _make_dataset(self, job_id: int = 42, batch_size: int = 2) -> RESTDataset: + """Create a RESTDataset pointing to the mock API.""" + return RESTDataset( + base_url="http://testserver/api/v2", + job_id=job_id, + batch_size=batch_size, + auth_token="test-token", + ) + + def test_multiple_batches(self): + """Dataset fetches multiple batches until queue is empty.""" + # Setup job with 3 images (all available in vermont dir), batch size 2 + image_urls = get_test_image_urls( + self.file_server, self.test_images_dir, subdir="vermont", num=3 + ) + tasks = [ + AntennaPipelineProcessingTask( + id=f"task_{i}", + image_id=f"img_{i}", + image_url=url, + reply_subject=f"reply_{i}", + ) + for i, url in enumerate(image_urls) + ] + antenna_api_server.setup_job(job_id=4, tasks=tasks) + + with patch_antenna_api_requests(self.antenna_client): + dataset = self._make_dataset(job_id=4, batch_size=2) + rows = list(dataset) + + # Should get all 3 images (batch1: 2 images, batch2: 1 image) + assert len(rows) == 3 + assert all(r["image"] is not None for r in rows) + + +# --------------------------------------------------------------------------- +# TestGetJobsIntegration - Integration tests for job fetching +# --------------------------------------------------------------------------- + + +class TestGetJobsIntegration(TestCase): + """Integration tests for get_jobs() with mock Antenna API.""" + + @classmethod + def setUpClass(cls): + cls.antenna_client = TestClient(antenna_app) + + def setUp(self): + antenna_api_server.reset() + + def test_returns_job_ids(self): + """Successfully fetches list of job IDs.""" + # Setup jobs in queue + antenna_api_server.setup_job(10, []) + antenna_api_server.setup_job(20, []) + antenna_api_server.setup_job(30, []) + + with patch_antenna_api_requests(self.antenna_client): + result = get_jobs("http://testserver/api/v2", "test-token", "moths_2024") + + assert result == [10, 20, 30] + + +# --------------------------------------------------------------------------- +# TestProcessJobIntegration - Integration tests with real ML inference +# --------------------------------------------------------------------------- + + +class TestProcessJobIntegration(TestCase): + """Integration tests for _process_job() with real detector and classifier.""" + + @classmethod + def setUpClass(cls): + cls.test_images_dir = pathlib.Path(TEST_IMAGES_BASE_PATH) + cls.file_server = StaticFileTestServer(cls.test_images_dir) + cls.file_server.start() # Start server and keep it running for all tests + cls.antenna_client = TestClient(antenna_app) + + @classmethod + def tearDownClass(cls): + cls.file_server.stop() + + def setUp(self): + antenna_api_server.reset() + + def _make_settings(self): + """Create mock settings for worker.""" + settings = MagicMock() + settings.antenna_api_base_url = "http://testserver/api/v2" + settings.antenna_api_auth_token = "test-token" + settings.antenna_api_batch_size = 2 + settings.num_workers = 0 # Disable multiprocessing for tests + settings.localization_batch_size = 2 # Real integer for batch processing + return settings + + def test_empty_queue(self): + """No tasks in queue → returns False.""" + antenna_api_server.setup_job(job_id=100, 
tasks=[]) + + with patch_antenna_api_requests(self.antenna_client): + result = _process_job( + "quebec_vermont_moths_2023", 100, self._make_settings() + ) + + assert result is False + + def test_processes_batch_with_real_inference(self): + """Worker fetches tasks, loads images, runs ML, posts results.""" + # Setup job with 2 test images + image_urls = get_test_image_urls( + self.file_server, self.test_images_dir, subdir="vermont", num=2 + ) + tasks = [ + AntennaPipelineProcessingTask( + id=f"task_{i}", + image_id=f"img_{i}", + image_url=url, + reply_subject=f"reply_{i}", + ) + for i, url in enumerate(image_urls) + ] + antenna_api_server.setup_job(job_id=101, tasks=tasks) + + # Run worker + with patch_antenna_api_requests(self.antenna_client): + result = _process_job( + "quebec_vermont_moths_2023", 101, self._make_settings() + ) + + # Validate processing succeeded + assert result is True + + # Validate results were posted + posted_results = antenna_api_server.get_posted_results(101) + assert len(posted_results) == 2 + + # Validate schema compliance + for task_result in posted_results: + assert isinstance(task_result, AntennaTaskResult) + assert isinstance(task_result.result, PipelineResultsResponse) + + # Validate structure + response = task_result.result + assert response.pipeline == "quebec_vermont_moths_2023" + assert response.total_time > 0 + assert len(response.source_images) == 1 + assert len(response.detections) >= 0 # May be 0 if no moths + + def test_handles_failed_items(self): + """Failed image downloads produce AntennaTaskResultError.""" + tasks = [ + AntennaPipelineProcessingTask( + id="task_fail", + image_id="img_fail", + image_url="http://invalid-url.test/image.jpg", + reply_subject="reply_fail", + ) + ] + antenna_api_server.setup_job(job_id=102, tasks=tasks) + + with patch_antenna_api_requests(self.antenna_client): + _process_job("quebec_vermont_moths_2023", 102, self._make_settings()) + + posted_results = antenna_api_server.get_posted_results(102) + assert len(posted_results) == 1 + assert isinstance(posted_results[0].result, AntennaTaskResultError) + assert posted_results[0].result.error # Error message should not be empty + + def test_mixed_batch_success_and_failures(self): + """Batch with some successful and some failed images.""" + # One valid image, one invalid + valid_url = get_test_image_urls( + self.file_server, self.test_images_dir, subdir="vermont", num=1 + )[0] + + tasks = [ + AntennaPipelineProcessingTask( + id="task_good", + image_id="img_good", + image_url=valid_url, + reply_subject="reply_good", + ), + AntennaPipelineProcessingTask( + id="task_bad", + image_id="img_bad", + image_url="http://invalid-url.test/bad.jpg", + reply_subject="reply_bad", + ), + ] + antenna_api_server.setup_job(job_id=103, tasks=tasks) + + with patch_antenna_api_requests(self.antenna_client): + result = _process_job( + "quebec_vermont_moths_2023", 103, self._make_settings() + ) + + assert result is True + posted_results = antenna_api_server.get_posted_results(103) + assert len(posted_results) == 2 + + # One success, one error + success_results = [ + r for r in posted_results if isinstance(r.result, PipelineResultsResponse) + ] + error_results = [ + r for r in posted_results if isinstance(r.result, AntennaTaskResultError) + ] + assert len(success_results) == 1 + assert len(error_results) == 1 + + +# --------------------------------------------------------------------------- +# TestWorkerEndToEnd - Full workflow integration tests +# 
--------------------------------------------------------------------------- + + +class TestWorkerEndToEnd(TestCase): + """End-to-end integration tests for complete worker workflow.""" + + @classmethod + def setUpClass(cls): + cls.test_images_dir = pathlib.Path(TEST_IMAGES_BASE_PATH) + cls.file_server = StaticFileTestServer(cls.test_images_dir) + cls.file_server.start() # Start server and keep it running for all tests + cls.antenna_client = TestClient(antenna_app) + + @classmethod + def tearDownClass(cls): + cls.file_server.stop() + + def setUp(self): + antenna_api_server.reset() + + def _make_settings(self): + settings = MagicMock() + settings.antenna_api_base_url = "http://testserver/api/v2" + settings.antenna_api_auth_token = "test-token" + settings.antenna_api_batch_size = 2 + settings.num_workers = 0 + settings.localization_batch_size = 2 # Real integer for batch processing + return settings + + def test_full_workflow_with_real_inference(self): + """ + Complete workflow: register → fetch jobs → fetch tasks → load images → + run detection → run classification → post results. + """ + pipeline_slug = "quebec_vermont_moths_2023" + + # Setup project and job with 2 test images + antenna_api_server.setup_projects([{"id": 1, "name": "Test Project"}]) + image_urls = get_test_image_urls( + self.file_server, self.test_images_dir, subdir="vermont", num=2 + ) + tasks = [ + AntennaPipelineProcessingTask( + id=f"task_{i}", + image_id=f"img_{i}", + image_url=url, + reply_subject=f"reply_{i}", + ) + for i, url in enumerate(image_urls) + ] + antenna_api_server.setup_job(job_id=200, tasks=tasks) + + with patch_antenna_api_requests(self.antenna_client): + # Step 1: Register pipeline + pipeline_configs = [ + PipelineConfigResponse( + name="Vermont Moths", slug=pipeline_slug, version=1 + ) + ] + success, _ = register_pipelines_for_project( + base_url="http://testserver/api/v2", + auth_token="test-token", + project_id=1, + service_name="Test Worker", + pipeline_configs=pipeline_configs, + ) + assert success is True + + # Step 2: Get jobs + job_ids = get_jobs( + "http://testserver/api/v2", + "test-token", + pipeline_slug, + ) + assert 200 in job_ids + + # Step 3: Process job + result = _process_job(pipeline_slug, 200, self._make_settings()) + assert result is True + + # Step 4: Validate results posted + posted_results = antenna_api_server.get_posted_results(200) + assert len(posted_results) == 2 + + # Validate all results are valid + for task_result in posted_results: + assert isinstance(task_result, AntennaTaskResult) + assert task_result.reply_subject is not None + + # Should be success results + assert isinstance(task_result.result, PipelineResultsResponse) + response = task_result.result + + # Validate pipeline response structure + assert response.pipeline == "quebec_vermont_moths_2023" + assert response.total_time > 0 + assert len(response.source_images) == 1 + + # Validate detections structure (may be empty if no moths) + assert isinstance(response.detections, list) + if response.detections: + detection = response.detections[0] + assert detection.bbox is not None + assert detection.source_image_id is not None + + def test_multiple_batches_processed(self): + """Job with more tasks than batch size processes in multiple batches.""" + # Setup job with 3 images (all available in vermont dir), batch size 2 + image_urls = get_test_image_urls( + self.file_server, self.test_images_dir, subdir="vermont", num=3 + ) + tasks = [ + AntennaPipelineProcessingTask( + id=f"task_{i}", + image_id=f"img_{i}", + 
image_url=url, + reply_subject=f"reply_{i}", + ) + for i, url in enumerate(image_urls) + ] + antenna_api_server.setup_job(job_id=201, tasks=tasks) + + with patch_antenna_api_requests(self.antenna_client): + result = _process_job( + "quebec_vermont_moths_2023", 201, self._make_settings() + ) + + assert result is True + + # All 3 results should be posted (batch1: 2, batch2: 1) + posted_results = antenna_api_server.get_posted_results(201) + assert len(posted_results) == 3 + + # All should be successful + assert all( + isinstance(r.result, PipelineResultsResponse) for r in posted_results + ) diff --git a/trapdata/antenna/worker.py b/trapdata/antenna/worker.py new file mode 100644 index 00000000..2fbf3b54 --- /dev/null +++ b/trapdata/antenna/worker.py @@ -0,0 +1,252 @@ +"""Worker loop for processing jobs from Antenna API.""" + +import datetime +import time + +import numpy as np +import torch + +from trapdata.antenna.client import get_jobs, post_batch_results +from trapdata.antenna.datasets import get_rest_dataloader +from trapdata.antenna.schemas import AntennaTaskResult, AntennaTaskResultError +from trapdata.api.api import CLASSIFIER_CHOICES +from trapdata.api.models.localization import APIMothDetector +from trapdata.api.schemas import ( + DetectionResponse, + PipelineResultsResponse, + SourceImageResponse, +) +from trapdata.common.logs import logger +from trapdata.common.utils import log_time +from trapdata.settings import Settings, read_settings + +SLEEP_TIME_SECONDS = 5 + + +def run_worker(pipelines: list[str]): + """Run the worker to process images from the REST API queue.""" + settings = read_settings() + + # Validate auth token + if not settings.antenna_api_auth_token: + raise ValueError( + "AMI_ANTENNA_API_AUTH_TOKEN environment variable must be set. " + "Get your auth token from your Antenna project settings." + ) + + while True: + # TODO CGJS: Support pulling and prioritizing single image tasks, which are used in interactive testing + # These should probably come from a dedicated endpoint and should preempt batch jobs under the assumption that they + # would run on the same GPU. + any_jobs = False + for pipeline in pipelines: + logger.info(f"Checking for jobs for pipeline {pipeline}") + jobs = get_jobs( + base_url=settings.antenna_api_base_url, + auth_token=settings.antenna_api_auth_token, + pipeline_slug=pipeline, + ) + for job_id in jobs: + logger.info(f"Processing job {job_id} with pipeline {pipeline}") + try: + any_work_done = _process_job( + pipeline=pipeline, + job_id=job_id, + settings=settings, + ) + any_jobs = any_jobs or any_work_done + except Exception as e: + logger.error( + f"Failed to process job {job_id} with pipeline {pipeline}: {e}", + exc_info=True, + ) + # Continue to next job rather than crashing the worker + + if not any_jobs: + logger.info(f"No jobs found, sleeping for {SLEEP_TIME_SECONDS} seconds") + time.sleep(SLEEP_TIME_SECONDS) + + +@torch.no_grad() +def _process_job( + pipeline: str, + job_id: int, + settings: Settings, +) -> bool: + """Run the worker to process images from the REST API queue. 
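+
+    Streams task batches for a single job from the Antenna API through the detector
+    and classifier, posting results back to Antenna after each batch.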
+ + Args: + pipeline: Pipeline name to use for processing (e.g., moth_binary, panama_moths_2024) + job_id: Job ID to process + settings: Settings object with antenna_api_* configuration + Returns: + True if any work was done, False otherwise + """ + did_work = False + loader = get_rest_dataloader(job_id=job_id, settings=settings) + classifier = None + detector = None + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + items = 0 + + total_detection_time = 0.0 + total_classification_time = 0.0 + total_save_time = 0.0 + total_dl_time = 0.0 + all_detections = [] + _, t = log_time() + + for i, batch in enumerate(loader): + dt, t = t("Finished loading batch") + total_dl_time += dt + if not batch: + logger.warning(f"Batch {i + 1} is empty, skipping") + continue + + # Defer instantiation of detector and classifier until we have data + if not classifier: + classifier_class = CLASSIFIER_CHOICES[pipeline] + classifier = classifier_class(source_images=[], detections=[]) + detector = APIMothDetector([]) + assert detector is not None, "Detector not initialized" + assert classifier is not None, "Classifier not initialized" + detector.reset([]) + did_work = True + + # Extract data from dictionary batch + images = batch.get("images", []) + image_ids = batch.get("image_ids", []) + reply_subjects = batch.get("reply_subjects", [None] * len(images)) + image_urls = batch.get("image_urls", [None] * len(images)) + + # Validate all arrays have same length before zipping + if len(image_ids) != len(images): + raise ValueError( + f"Length mismatch: image_ids ({len(image_ids)}) != images ({len(images)})" + ) + if len(image_ids) != len(reply_subjects) or len(image_ids) != len(image_urls): + raise ValueError( + f"Length mismatch: image_ids ({len(image_ids)}), " + f"reply_subjects ({len(reply_subjects)}), image_urls ({len(image_urls)})" + ) + + # Track start time for this batch + batch_start_time = datetime.datetime.now() + + logger.info(f"Processing batch {i + 1}") + # output is dict of "boxes", "labels", "scores" + batch_output = [] + if len(images) > 0: + batch_output = detector.predict_batch(images) + + items += len(batch_output) + logger.info(f"Total items processed so far: {items}") + batch_output = list(detector.post_process_batch(batch_output)) + + # Convert image_ids to list if needed + if isinstance(image_ids, (np.ndarray, torch.Tensor)): + image_ids = image_ids.tolist() + + # TODO CGJS: Add seconds per item calculation for both detector and classifier + detector.save_results( + item_ids=image_ids, + batch_output=batch_output, + seconds_per_item=0, + ) + dt, t = t("Finished detection") + total_detection_time += dt + + # Group detections by image_id + image_detections: dict[str, list[DetectionResponse]] = { + img_id: [] for img_id in image_ids + } + image_tensors = dict(zip(image_ids, images, strict=True)) + + classifier.reset(detector.results) + + for idx, dresp in enumerate(detector.results): + image_tensor = image_tensors[dresp.source_image_id] + bbox = dresp.bbox + # crop the image tensor using the bbox + crop = image_tensor[ + :, int(bbox.y1) : int(bbox.y2), int(bbox.x1) : int(bbox.x2) + ] + crop = crop.unsqueeze(0) # add batch dimension + classifier_out = classifier.predict_batch(crop) + classifier_out = classifier.post_process_batch(classifier_out) + detection = classifier.update_detection_classification( + seconds_per_item=0, + image_id=dresp.source_image_id, + detection_idx=idx, + predictions=classifier_out[0], + ) + image_detections[dresp.source_image_id].append(detection) + 
all_detections.append(detection) + + ct, t = t("Finished classification") + total_classification_time += ct + + # Calculate batch processing time + batch_end_time = datetime.datetime.now() + batch_elapsed = (batch_end_time - batch_start_time).total_seconds() + + # Post results back to the API with PipelineResponse for each image + batch_results: list[AntennaTaskResult] = [] + for reply_subject, image_id, image_url in zip( + reply_subjects, image_ids, image_urls, strict=True + ): + # Create SourceImageResponse for this image + source_image = SourceImageResponse(id=image_id, url=image_url) + + # Create PipelineResultsResponse + pipeline_response = PipelineResultsResponse( + pipeline=pipeline, + source_images=[source_image], + detections=image_detections[image_id], + total_time=batch_elapsed / len(image_ids), # Approximate time per image + ) + + batch_results.append( + AntennaTaskResult( + reply_subject=reply_subject, + result=pipeline_response, + ) + ) + failed_items = batch.get("failed_items") + if failed_items: + for failed_item in failed_items: + batch_results.append( + AntennaTaskResult( + reply_subject=failed_item.get("reply_subject"), + result=AntennaTaskResultError( + error=failed_item.get("error", "Unknown error"), + image_id=failed_item.get("image_id"), + ), + ) + ) + + success = post_batch_results( + settings.antenna_api_base_url, + settings.antenna_api_auth_token, + job_id, + batch_results, + ) + st, t = t("Finished posting results") + + if not success: + error_msg = ( + f"Failed to post {len(batch_results)} results for job {job_id} to " + f"{settings.antenna_api_base_url}. Batch processing data lost." + ) + logger.error(error_msg) + raise RuntimeError(error_msg) + + total_save_time += st + + logger.info( + f"Done, detections: {len(all_detections)}. Detecting time: {total_detection_time}, " + f"classification time: {total_classification_time}, dl time: {total_dl_time}, save time: {total_save_time}" + ) + return did_work diff --git a/trapdata/api/api.py b/trapdata/api/api.py index cc3044c5..47f34fec 100644 --- a/trapdata/api/api.py +++ b/trapdata/api/api.py @@ -5,6 +5,7 @@ import enum import time +from contextlib import asynccontextmanager import fastapi import pydantic @@ -36,7 +37,18 @@ from .schemas import PipelineResultsResponse as PipelineResponse_ from .schemas import ProcessingServiceInfoResponse, SourceImage, SourceImageResponse -app = fastapi.FastAPI() + +@asynccontextmanager +async def lifespan(app: fastapi.FastAPI): + # cache the service info to be built only once at startup + app.state.service_info = initialize_service_info() + logger.info("Initialized service info") + yield + # Shutdown event: Clean up resources (if necessary) + logger.info("Shutting down API") + + +app = fastapi.FastAPI(lifespan=lifespan) app.add_middleware(GZipMiddleware) @@ -91,7 +103,6 @@ def make_category_map_response( def make_algorithm_response( model: APIMothDetector | APIMothClassifier, ) -> AlgorithmConfigResponse: - category_map = make_category_map_response(model) if model.category_map else None return AlgorithmConfigResponse( name=model.name, @@ -157,13 +168,6 @@ def make_pipeline_config_response( ) -# @TODO This requires loading all models into memory! Can we avoid this? 
-PIPELINE_CONFIGS = [ - make_pipeline_config_response(classifier_class, slug=key) - for key, classifier_class in CLASSIFIER_CHOICES.items() -] - - class PipelineRequest(PipelineRequest_): pipeline: PipelineChoice = pydantic.Field( description=PipelineRequest_.model_fields["pipeline"].description, @@ -313,17 +317,7 @@ async def process(data: PipelineRequest) -> PipelineResponse: @app.get("/info", tags=["services"]) async def info() -> ProcessingServiceInfoResponse: - info = ProcessingServiceInfoResponse( - name="Antenna Inference API", - description=( - "The primary endpoint for processing images for the Antenna platform. " - "This API provides access to multiple detection and classification " - "algorithms by multiple labs for processing images of moths." - ), - pipelines=PIPELINE_CONFIGS, - # algorithms=list(algorithm_choices.values()), - ) - return info + return app.state.service_info # Check if the server is online @@ -361,6 +355,26 @@ async def readyz(): # pass +def initialize_service_info() -> ProcessingServiceInfoResponse: + # @TODO This requires loading all models into memory! Can we avoid this? + pipeline_configs = [ + make_pipeline_config_response(classifier_class, slug=key) + for key, classifier_class in CLASSIFIER_CHOICES.items() + ] + + _info = ProcessingServiceInfoResponse( + name="Antenna Inference API", + description=( + "The primary endpoint for processing images for the Antenna platform. " + "This API provides access to multiple detection and classification " + "algorithms by multiple labs for processing images of moths." + ), + pipelines=pipeline_configs, + # algorithms=list(algorithm_choices.values()), + ) + return _info + + if __name__ == "__main__": import uvicorn diff --git a/trapdata/api/models/classification.py b/trapdata/api/models/classification.py index 482c4ac3..e604f3c8 100644 --- a/trapdata/api/models/classification.py +++ b/trapdata/api/models/classification.py @@ -54,6 +54,10 @@ def __init__( "detections" ) + def reset(self, detections: typing.Iterable[DetectionResponse]): + self.detections = list(detections) + self.results = [] + def get_dataset(self): return ClassificationImageDataset( source_images=self.source_images, @@ -117,19 +121,12 @@ def save_results( for image_id, detection_idx, predictions in zip( image_ids, detection_idxes, batch_output ): - detection = self.detections[detection_idx] - assert detection.source_image_id == image_id - - classification = ClassificationResponse( - classification=self.get_best_label(predictions), - scores=predictions.scores, - logits=predictions.logit, - inference_time=seconds_per_item, - algorithm=AlgorithmReference(name=self.name, key=self.get_key()), - timestamp=datetime.datetime.now(), - terminal=self.terminal, + self.update_detection_classification( + seconds_per_item, + image_id, + detection_idx, + predictions, ) - self.update_classification(detection, classification) self.results = self.detections logger.info(f"Saving {len(self.results)} detections with classifications") @@ -149,6 +146,32 @@ def update_classification( f"Total classifications: {len(detection.classifications)}" ) + def update_detection_classification( + self, + seconds_per_item: float, + image_id: str, + detection_idx: int, + predictions: ClassifierResult, + ) -> DetectionResponse: + detection = self.detections[detection_idx] + if detection.source_image_id != image_id: + raise ValueError( + f"Detection index {detection_idx} has mismatched image_id: " + f"expected '{image_id}', got '{detection.source_image_id}'" + ) + + classification = 
ClassificationResponse( + classification=self.get_best_label(predictions), + scores=predictions.scores, + logits=predictions.logit, + inference_time=seconds_per_item, + algorithm=AlgorithmReference(name=self.name, key=self.get_key()), + timestamp=datetime.datetime.now(), + terminal=self.terminal, + ) + self.update_classification(detection, classification) + return detection + def run(self) -> list[DetectionResponse]: logger.info( f"Starting {self.__class__.__name__} run with {len(self.results)} " diff --git a/trapdata/api/models/localization.py b/trapdata/api/models/localization.py index 600fc9f7..9ec1acd5 100644 --- a/trapdata/api/models/localization.py +++ b/trapdata/api/models/localization.py @@ -1,4 +1,3 @@ -import concurrent.futures import datetime import typing @@ -17,6 +16,10 @@ def __init__(self, source_images: typing.Iterable[SourceImage], *args, **kwargs) self.results: list[DetectionResponse] = [] super().__init__(*args, **kwargs) + def reset(self, source_images: typing.Iterable[SourceImage]): + self.source_images = source_images + self.results = [] + def get_dataset(self): return LocalizationImageDataset( self.source_images, self.get_transforms(), batch_size=self.batch_size @@ -43,15 +46,9 @@ def save_detection(image_id, coords): ) return detection - with concurrent.futures.ThreadPoolExecutor() as executor: - futures = [] - for image_id, image_output in zip(item_ids, batch_output): - for coords in image_output: - future = executor.submit(save_detection, image_id, coords) - futures.append(future) - - for future in concurrent.futures.as_completed(futures): - detection = future.result() + for image_id, image_output in zip(item_ids, batch_output): + for coords in image_output: + detection = save_detection(image_id, coords) detections.append(detection) self.results += detections diff --git a/trapdata/api/service.conf b/trapdata/api/service.conf index 745b39c1..f592c340 100644 --- a/trapdata/api/service.conf +++ b/trapdata/api/service.conf @@ -1,3 +1,5 @@ +# Example supervisord configuration for AMI API server +# to run it as a continuous background service [program:ami] directory=/home/debian/ami-data-companion command=/home/debian/miniconda3/bin/ami api diff --git a/trapdata/api/tests/__init__.py b/trapdata/api/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/trapdata/api/tests/test_api.py b/trapdata/api/tests/test_api.py index 3a98dc89..84dba6c7 100644 --- a/trapdata/api/tests/test_api.py +++ b/trapdata/api/tests/test_api.py @@ -1,13 +1,11 @@ import logging import pathlib -from typing import Type from unittest import TestCase from fastapi.testclient import TestClient from trapdata.api.api import ( CLASSIFIER_CHOICES, - APIMothClassifier, PipelineChoice, PipelineRequest, PipelineResponse, @@ -15,8 +13,9 @@ make_algorithm_response, make_pipeline_config_response, ) -from trapdata.api.schemas import PipelineConfigRequest, SourceImageRequest +from trapdata.api.schemas import PipelineConfigRequest from trapdata.api.tests.image_server import StaticFileTestServer +from trapdata.api.tests.utils import get_test_images, get_pipeline_class from trapdata.tests import TEST_IMAGES_BASE_PATH logging.basicConfig(level=logging.INFO) @@ -40,22 +39,10 @@ def tearDownClass(cls): cls.file_server.stop() def get_test_images(self, subdir: str = "vermont", num: int = 2): - images_dir = self.test_images_dir / subdir - source_image_urls = [ - self.file_server.get_url(f.relative_to(self.test_images_dir)) - for f in images_dir.glob("*.jpg") - ][:num] - source_images = [ - 
SourceImageRequest(id=str(i), url=url) - for i, url in enumerate(source_image_urls) - ] - return source_images + return get_test_images(self.file_server, self.test_images_dir, subdir, num) - def get_test_pipeline( - self, slug: str = "quebec_vermont_moths_2023" - ) -> Type[APIMothClassifier]: - pipeline = CLASSIFIER_CHOICES[slug] - return pipeline + def get_test_pipeline(self, slug: str = "quebec_vermont_moths_2023"): + return get_pipeline_class(slug) def test_pipeline_request(self): """ diff --git a/trapdata/api/tests/utils.py b/trapdata/api/tests/utils.py new file mode 100644 index 00000000..eda17bc5 --- /dev/null +++ b/trapdata/api/tests/utils.py @@ -0,0 +1,124 @@ +"""Shared test utilities for API tests.""" + +from contextlib import contextmanager +from pathlib import Path +from typing import Type +from unittest.mock import patch + +from fastapi.testclient import TestClient + +from trapdata.api.api import CLASSIFIER_CHOICES, APIMothClassifier +from trapdata.api.schemas import SourceImageRequest +from trapdata.api.tests.image_server import StaticFileTestServer + + +def get_test_image_urls( + file_server: StaticFileTestServer, + test_images_dir: Path, + subdir: str = "vermont", + num: int = 2, +) -> list[str]: + """Get list of test image URLs from file server. + + Args: + file_server: StaticFileTestServer instance + test_images_dir: Base directory containing test images + subdir: Subdirectory within test_images_dir (default: "vermont") + num: Number of images to return (default: 2) + + Returns: + List of image URLs from the file server + """ + images_dir = test_images_dir / subdir + source_image_urls = [ + file_server.get_url(f.relative_to(test_images_dir)) + for f in images_dir.glob("*.jpg") + ][:num] + return source_image_urls + + +def get_test_images( + file_server: StaticFileTestServer, + test_images_dir: Path, + subdir: str = "vermont", + num: int = 2, +) -> list[SourceImageRequest]: + """Get list of SourceImageRequest objects for testing. + + Args: + file_server: StaticFileTestServer instance + test_images_dir: Base directory containing test images + subdir: Subdirectory within test_images_dir (default: "vermont") + num: Number of images to return (default: 2) + + Returns: + List of SourceImageRequest objects with IDs and URLs + """ + urls = get_test_image_urls(file_server, test_images_dir, subdir, num) + source_images = [ + SourceImageRequest(id=str(i), url=url) for i, url in enumerate(urls) + ] + return source_images + + +def get_pipeline_class( + slug: str = "quebec_vermont_moths_2023", +) -> Type[APIMothClassifier]: + """Get classifier class by pipeline slug. + + Args: + slug: Pipeline slug (default: "quebec_vermont_moths_2023") + + Returns: + APIMothClassifier class for the specified pipeline + """ + return CLASSIFIER_CHOICES[slug] + + +@contextmanager +def patch_antenna_api_requests(test_client: TestClient): + """Patch requests.get/post to route through TestClient. + + This allows tests to mock the Antenna API by routing requests through + a TestClient instead of making real HTTP calls. Only requests to + http://testserver are mocked - other requests pass through normally. 
+ + Args: + test_client: FastAPI TestClient to route requests through + + Usage: + with patch_antenna_api_requests(antenna_client): + # Code that makes requests to Antenna API + response = requests.get("http://testserver/api/v2/jobs") + """ + import requests + + # Save original methods BEFORE patching + original_session_get = requests.Session.get + original_session_post = requests.Session.post + + def mock_session_get(self, url, **kwargs): + """Mock Session.get - route testserver through TestClient, others pass through.""" + if "testserver" in url: + path = url.replace("http://testserver", "") + headers = kwargs.get("headers", {}) + params = kwargs.get("params", {}) + return test_client.get(path, headers=headers, params=params) + else: + # Let real HTTP requests through (e.g., to file server) + return original_session_get(self, url, **kwargs) + + def mock_session_post(self, url, **kwargs): + """Mock Session.post - route testserver through TestClient, others pass through.""" + if "testserver" in url: + path = url.replace("http://testserver", "") + headers = kwargs.get("headers", {}) + json_data = kwargs.get("json") + return test_client.post(path, headers=headers, json=json_data) + else: + return original_session_post(self, url, **kwargs) + + # Patch Session methods (used by get_http_session) + with patch.object(requests.Session, "get", mock_session_get): + with patch.object(requests.Session, "post", mock_session_post): + yield diff --git a/trapdata/api/utils.py b/trapdata/api/utils.py index 2c2678a4..3d8e02a7 100644 --- a/trapdata/api/utils.py +++ b/trapdata/api/utils.py @@ -2,6 +2,9 @@ import time import PIL.Image +import requests +from requests.adapters import HTTPAdapter +from urllib3.util.retry import Retry from ..common.utils import slugify from .schemas import BoundingBox, SourceImage @@ -33,3 +36,40 @@ def get_crop_fname(source_image: SourceImage, bbox: BoundingBox) -> str: bbox_name = bbox.to_path() timestamp = int(time.time()) # @TODO use pipeline name/version instead return f"{source_name}/{bbox_name}-{timestamp}.jpg" + + +def get_http_session(auth_token: str | None = None) -> requests.Session: + """ + Create an HTTP session with retry logic for transient failures. + + Configures a requests.Session with HTTPAdapter and urllib3.Retry to automatically + retry failed requests with exponential backoff. Only retries on server errors (5XX) + and network failures, NOT on client errors (4XX). Only GET requests are retried. + + TODO: This will likely become part of an AntennaClient class that encapsulates + base_url, auth_token, and session management. See docs/claude/planning/antenna-client.md + + Args: + auth_token: Optional API token. If provided, adds "Token {auth_token}" header. 
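(Aside: a hedged usage sketch of the session returned by this helper; the endpoint path, query params, and token below are placeholders rather than the project's actual API surface.)

```python
# Hypothetical usage of get_http_session(); retries apply to GET requests only.
from trapdata.api.utils import get_http_session

session = get_http_session(auth_token="example-token")
response = session.get(
    "http://localhost:8000/api/v2/jobs",  # illustrative endpoint
    params={"ids_only": 1},
    timeout=30,
)
response.raise_for_status()
jobs = response.json()
```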
+ + Returns: + Configured requests.Session with retry adapter mounted + """ + session = requests.Session() + + retry_strategy = Retry( + total=3, + backoff_factor=0.5, + status_forcelist=(500, 502, 503, 504), + allowed_methods=["GET"], + raise_on_status=False, + ) + + adapter = HTTPAdapter(max_retries=retry_strategy) + session.mount("http://", adapter) + session.mount("https://", adapter) + + if auth_token: + session.headers["Authorization"] = f"Token {auth_token}" + + return session diff --git a/trapdata/cli/base.py b/trapdata/cli/base.py index f53cb651..59c69e8a 100644 --- a/trapdata/cli/base.py +++ b/trapdata/cli/base.py @@ -1,15 +1,17 @@ import pathlib -from typing import Optional +from typing import Annotated, Optional import typer -from trapdata.cli import db, export, queue, settings, shell, show, test +from trapdata.api.api import CLASSIFIER_CHOICES +from trapdata.cli import db, export, queue, settings, shell, show, test, worker from trapdata.db.base import get_session_class from trapdata.db.models.events import get_or_create_monitoring_sessions from trapdata.db.models.queue import add_monitoring_session_to_queue from trapdata.ml.pipeline import start_pipeline -cli = typer.Typer(no_args_is_help=True) +# don't display variable values in errors: +cli = typer.Typer(no_args_is_help=True, pretty_exceptions_show_locals=False) cli.add_typer(export.cli, name="export", help="Export data in various formats") cli.add_typer(shell.cli, name="shell", help="Open an interactive shell") cli.add_typer(test.cli, name="test", help="Run tests") @@ -18,6 +20,7 @@ cli.add_typer( queue.cli, name="queue", help="Add and manage images in the processing queue" ) +cli.add_typer(worker.cli, name="worker", help="Antenna worker for remote processing") @cli.command() diff --git a/trapdata/cli/test.py b/trapdata/cli/test.py index 9d5198f5..d52237d6 100644 --- a/trapdata/cli/test.py +++ b/trapdata/cli/test.py @@ -24,7 +24,7 @@ def all(): # return_code = pytest.main(["--doctest-modules", "-v", "."]) # return_code = pytest.main(["-v", "."]) - return_code = subprocess.call(["pytest", "-v", "."]) + return_code = subprocess.call([sys.executable, "-m", "pytest", "-v"]) sys.exit(return_code) diff --git a/trapdata/cli/worker.py b/trapdata/cli/worker.py new file mode 100644 index 00000000..19fb97aa --- /dev/null +++ b/trapdata/cli/worker.py @@ -0,0 +1,80 @@ +"""CLI commands for Antenna worker.""" + +from typing import Annotated + +import typer + +from trapdata.api.api import CLASSIFIER_CHOICES + +cli = typer.Typer(help="Antenna worker commands for remote processing") + + +@cli.callback(invoke_without_command=True) +def run( + ctx: typer.Context, + pipelines: Annotated[ + list[str] | None, + typer.Option( + "--pipeline", + help="Pipeline to use for processing (e.g., moth_binary, panama_moths_2024). Can be specified multiple times. Defaults to all pipelines if not specified." + ), + ] = None, +): + """ + Run the worker to process images from the Antenna API queue. + + Can be invoked as 'ami worker' or 'ami worker run'. + """ + # Only run the worker if no subcommand was invoked + if ctx.invoked_subcommand is not None: + return + + if not pipelines: + pipelines = list(CLASSIFIER_CHOICES.keys()) + + # Validate that each pipeline is in CLASSIFIER_CHOICES + invalid_pipelines = [ + pipeline for pipeline in pipelines if pipeline not in CLASSIFIER_CHOICES.keys() + ] + + if invalid_pipelines: + raise typer.BadParameter( + f"Invalid pipeline(s): {', '.join(invalid_pipelines)}. 
Must be one of: {', '.join(CLASSIFIER_CHOICES.keys())}" + ) + + from trapdata.antenna.worker import run_worker + + run_worker(pipelines=pipelines) + + +@cli.command("register") +def register( + name: Annotated[ + str, + typer.Argument( + help="Name for the processing service registration (e.g., 'AMI Data Companion on DRAC gpu-03'). " + "Hostname will be added automatically.", + ), + ], + project: Annotated[ + list[int] | None, + typer.Option( + help="Specific project IDs to register pipelines for. " + "If not specified, registers for all accessible projects.", + ), + ] = None, +): + """ + Register available pipelines with the Antenna platform for specified projects. + + This command registers all available pipeline configurations with the Antenna platform + for the specified projects (or all accessible projects if none specified). + + Examples: + ami worker register "AMI Data Companion on DRAC gpu-03" --project 1 --project 2 + ami worker register "My Processing Service" # registers for all accessible projects + """ + from trapdata.antenna.registration import register_pipelines + + project_ids = project if project else [] + register_pipelines(project_ids=project_ids, service_name=name) diff --git a/trapdata/common/tests/__init__.py b/trapdata/common/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/trapdata/common/utils.py b/trapdata/common/utils.py index 15d11b2a..5d0813ad 100644 --- a/trapdata/common/utils.py +++ b/trapdata/common/utils.py @@ -1,9 +1,11 @@ import csv import datetime +import functools import pathlib import random import string -from typing import Any, Union +import time +from typing import Any, Callable, Tuple, Union def get_sequential_sample(direction, images, last_sample=None): @@ -119,3 +121,29 @@ def random_color(): color = [random.random() for _ in range(3)] color.append(0.8) # alpha return color + + +def log_time(start: float = 0, msg: str | None = None) -> Tuple[float, Callable]: + """ + Small helper to measure time between calls. 
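(Aside: a runnable sketch of the chained-timer pattern this helper enables, complementing the usage notes just below; the sleeps stand in for real work such as downloading images and running detection.)

```python
# Illustrative timing of two steps with log_time(); values are approximate.
import time

from trapdata.common.utils import log_time

_, t = log_time()                      # start the timer; elapsed is 0.0 here
time.sleep(0.2)                        # ...download images...
dl_time, t = t("Downloaded images")    # logs and returns ~0.2s
time.sleep(0.1)                        # ...run detection...
det_time, t = t("Finished detection")  # logs and returns ~0.1s
print(f"download={dl_time:.2f}s detection={det_time:.2f}s")
```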
+ + Returns: elapsed time since the last call, and a partial function to measure from the current call + Usage: + + _, tlog = log_time() + # do something + _, tlog = tlog("Did something") # will log the time taken by 'something' + # do something else + t, tlog = tlog("Did something else") # will log the time taken by 'something else', returned as 't' + """ + from trapdata.common.logs import logger + + end = time.perf_counter() + if start == 0: + dur = 0.0 + else: + dur = end - start + if msg and start > 0: + logger.info(f"{msg}: {dur:.3f}s") + new_start = time.perf_counter() + return dur, functools.partial(log_time, new_start) diff --git a/trapdata/db/models/detections.py b/trapdata/db/models/detections.py index 6db1771e..b5babb01 100644 --- a/trapdata/db/models/detections.py +++ b/trapdata/db/models/detections.py @@ -323,7 +323,7 @@ def save_detected_objects( # CRITICAL PERFORMANCE FIX: Batch fetch all previous images at once # This eliminates the N+1 query problem where previous_image was called for each detection - all_image_ids = [img.id for img in images] + _ = [img.id for img in images] # Create a mapping of image_id to previous_image_id image_to_previous = {} @@ -404,9 +404,7 @@ def save_classified_objects(db_path, object_ids, classified_objects_data): # Use a single session for all operations with db.get_session(db_path) as sesh: # Batch fetch all objects at once - objects = ( - sesh.query(DetectedObject).filter(DetectedObject.id.in_(object_ids)).all() - ) + _ = sesh.query(DetectedObject).filter(DetectedObject.id.in_(object_ids)).all() timestamp = datetime.datetime.now() update_data = [] diff --git a/trapdata/ml/models/base.py b/trapdata/ml/models/base.py index 09c32be1..bb7d1fa6 100644 --- a/trapdata/ml/models/base.py +++ b/trapdata/ml/models/base.py @@ -298,7 +298,8 @@ def save_results( @torch.no_grad() def run(self): - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() logger.info(f"Running inference ({self.name})\n\n") num_batches_total = ceil(len(self.dataloader) / self.batch_size) for i, batch in enumerate(self.dataloader): diff --git a/trapdata/settings.py b/trapdata/settings.py index 6ce566ed..f4b83f16 100644 --- a/trapdata/settings.py +++ b/trapdata/settings.py @@ -37,6 +37,11 @@ class Settings(BaseSettings): classification_batch_size: int = 20 num_workers: int = 1 + # Antenna API worker settings + antenna_api_base_url: str = "http://localhost:8000/api/v2" + antenna_api_auth_token: str = "" + antenna_api_batch_size: int = 4 + @pydantic.field_validator("image_base_path", "user_data_path") def validate_path(cls, v): """ @@ -143,6 +148,24 @@ class Config: "kivy_type": "numeric", "kivy_section": "performance", }, + "antenna_api_base_url": { + "title": "Antenna API Base URL", + "description": "URL to the Antenna platform API for worker processing (should include /api/v2)", + "kivy_type": "string", + "kivy_section": "antenna", + }, + "antenna_api_auth_token": { + "title": "Antenna API Token", + "description": "Authentication token for your Antenna project", + "kivy_type": "string", + "kivy_section": "antenna", + }, + "antenna_api_batch_size": { + "title": "Antenna API Batch Size", + "description": "Number of tasks to fetch from Antenna per batch", + "kivy_type": "numeric", + "kivy_section": "antenna", + }, } @classmethod
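The new `antenna_api_*` fields follow the same convention as the existing settings: defaults in code, overridable via `AMI_`-prefixed environment variables. A self-contained sketch of that mapping using pydantic-settings v2 conventions (the `AntennaSettings` class is hypothetical, and the project's actual `Settings` class may declare its env prefix differently):

```python
# Standalone illustration of env-var overrides for the new Antenna fields;
# AntennaSettings is illustrative and not part of the project.
import os

from pydantic_settings import BaseSettings, SettingsConfigDict


class AntennaSettings(BaseSettings):
    model_config = SettingsConfigDict(env_prefix="AMI_")

    antenna_api_base_url: str = "http://localhost:8000/api/v2"
    antenna_api_auth_token: str = ""
    antenna_api_batch_size: int = 4


os.environ["AMI_ANTENNA_API_AUTH_TOKEN"] = "example-token"
settings = AntennaSettings()
assert settings.antenna_api_auth_token == "example-token"
assert settings.antenna_api_batch_size == 4  # default, no override set
```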