Merged

Changes from 12 commits

46 commits
9129421
WIP: Pull worker and REST dataset
carlosgjs Oct 15, 2025
41fef93
Clean-up, add "worker" cli command, move token to env var
carlosgjs Oct 17, 2025
87910aa
Post results back
carlosgjs Oct 17, 2025
c67afce
Progress updates working
carlosgjs Oct 17, 2025
64e188d
clean up
carlosgjs Oct 24, 2025
c00de9d
Better error handling
carlosgjs Nov 4, 2025
3b60538
Support multiple pipelines
carlosgjs Dec 4, 2025
45e68bc
Use app.state for the service info
carlosgjs Dec 5, 2025
3c4dd8c
API launch target
carlosgjs Dec 5, 2025
8f76365
Integration fixes
carlosgjs Dec 9, 2025
bef1cd7
Use PipelineProcessingTask instead of raw dicts
carlosgjs Dec 10, 2025
52cff32
Fix to returned results
carlos-irreverentlabs Dec 12, 2025
f3f3cd6
Trigger CI workflows
mihow Jan 24, 2026
589cd0d
Add Antenna API settings for worker configuration
mihow Jan 24, 2026
c4147bd
Add Pydantic schemas for Antenna API responses
mihow Jan 24, 2026
f7f454a
Refactor worker to use Settings pattern and improve robustness
mihow Jan 24, 2026
7846510
Improve datasets error handling and API contract
mihow Jan 24, 2026
822c436
Add type annotations to update_detection_classification
mihow Jan 24, 2026
2f26e0f
Add Antenna worker documentation
mihow Jan 24, 2026
99e685e
Update poetry.lock with dependency updates
mihow Jan 24, 2026
ab073b3
Replace fragile urljoin with explicit f-string URL construction
mihow Jan 24, 2026
078aa26
Use plural names for batch dict keys containing lists
mihow Jan 24, 2026
38942ee
Merge branch 'main' of https://github.com/RolnickLab/ami-data-manager…
mihow Jan 24, 2026
ce1d754
Fix API tests not running in main test suite
mihow Jan 24, 2026
29172d7
Rename batch result schemas to use Antenna prefix for consistency
mihow Jan 24, 2026
d85bafb
turn off typer show locals
carlosgjs Jan 27, 2026
22c4182
add back help text
carlosgjs Jan 27, 2026
a30ffd5
Flake fixes
carlosgjs Jan 27, 2026
5baab55
Fix REST dataloader to use localization_batch_size for inference batc…
mihow Jan 28, 2026
1bf5ee5
Fix type annotations to use explicit | None syntax
mihow Jan 28, 2026
1a523b2
Retry worker API requests with urllib3 adapter, reuse sessions (#104)
mihow Jan 28, 2026
9bd7142
AMI: Pipeline Registration (#106)
mihow Jan 29, 2026
602b2bc
Address code review feedback
mihow Jan 29, 2026
b1b184c
Disable POST retries by default in get_http_session
mihow Jan 29, 2026
ce3d967
Add validation and error handling improvements
mihow Jan 29, 2026
15d07c4
Remove redundant worker tests
mihow Jan 29, 2026
15da4dd
Refactor: Extract Antenna integration into dedicated module
mihow Jan 30, 2026
3825517
chore: remove temporary plans
mihow Jan 30, 2026
8e9c7fb
Simplify HTTP session config: hardcode retry, pass auth explicitly
mihow Jan 30, 2026
1c5ed89
feat: add example service file for Antenna worker, add comments
mihow Jan 30, 2026
b427ed2
fix: guard torch.cuda.empty_cache() calls with is_available() check
mihow Jan 30, 2026
2cc0259
fix: use sys.executable for pytest subprocess call
mihow Jan 30, 2026
2594bf3
fix: handle post_batch_results failure to prevent silent data loss
mihow Jan 30, 2026
4598278
refactor: make 'ami worker' the default command, use singular --pipel…
mihow Jan 30, 2026
361da2a
fix: handle post_batch_results failure to prevent silent data loss
mihow Jan 30, 2026
c4df11c
chore: remove validate_dwc_export.py (not meant for this PR)
mihow Jan 30, 2026
29 changes: 29 additions & 0 deletions .vscode/launch.json
@@ -0,0 +1,29 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python Debugger: Current File",
"type": "debugpy",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal"
},
{
"name": "Run worker",
"type": "debugpy",
"request": "launch",
"module": "trapdata.cli.base",
"args": ["worker"]
},
{
"name": "Run api",
"type": "debugpy",
"request": "launch",
"module": "trapdata.cli.base",
"args": ["api"]
}
]
}
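Note: these two launch configurations are equivalent to running python -m trapdata.cli.base worker and python -m trapdata.cli.base api from the command line (assuming the project's virtual environment is active); the debugger simply invokes the same module with the given arguments.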
53 changes: 34 additions & 19 deletions trapdata/api/api.py
@@ -5,6 +5,7 @@

import enum
import time
from contextlib import asynccontextmanager

import fastapi
import pydantic
@@ -36,7 +37,18 @@
from .schemas import PipelineResultsResponse as PipelineResponse_
from .schemas import ProcessingServiceInfoResponse, SourceImage, SourceImageResponse

app = fastapi.FastAPI()

@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):
# cache the service info to be built only once at startup
app.state.service_info = initialize_service_info()
logger.info("Initialized service info")
yield
# Shutdown event: Clean up resources (if necessary)
logger.info("Shutting down API")


app = fastapi.FastAPI(lifespan=lifespan)
app.add_middleware(GZipMiddleware)


@@ -157,13 +169,6 @@ def make_pipeline_config_response(
)


# @TODO This requires loading all models into memory! Can we avoid this?
PIPELINE_CONFIGS = [
make_pipeline_config_response(classifier_class, slug=key)
for key, classifier_class in CLASSIFIER_CHOICES.items()
]


class PipelineRequest(PipelineRequest_):
pipeline: PipelineChoice = pydantic.Field(
description=PipelineRequest_.model_fields["pipeline"].description,
@@ -313,17 +318,7 @@ async def process(data: PipelineRequest) -> PipelineResponse:

@app.get("/info", tags=["services"])
async def info() -> ProcessingServiceInfoResponse:
info = ProcessingServiceInfoResponse(
name="Antenna Inference API",
description=(
"The primary endpoint for processing images for the Antenna platform. "
"This API provides access to multiple detection and classification "
"algorithms by multiple labs for processing images of moths."
),
pipelines=PIPELINE_CONFIGS,
# algorithms=list(algorithm_choices.values()),
)
return info
return app.state.service_info


# Check if the server is online
@@ -361,6 +356,26 @@ async def readyz():
# pass


def initialize_service_info() -> ProcessingServiceInfoResponse:
# @TODO This requires loading all models into memory! Can we avoid this?
pipeline_configs = [
make_pipeline_config_response(classifier_class, slug=key)
for key, classifier_class in CLASSIFIER_CHOICES.items()
]

_info = ProcessingServiceInfoResponse(
name="Antenna Inference API",
description=(
"The primary endpoint for processing images for the Antenna platform. "
"This API provides access to multiple detection and classification "
"algorithms by multiple labs for processing images of moths."
),
pipelines=pipeline_configs,
# algorithms=list(algorithm_choices.values()),
)
return _info


if __name__ == "__main__":
import uvicorn

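For context, a minimal sketch of the startup-caching pattern this diff adopts: expensive state is built once in a FastAPI lifespan handler, stored on app.state, and read back by request handlers. This is illustrative only; build_expensive_info is a hypothetical stand-in for initialize_service_info(), which loads all models.

from contextlib import asynccontextmanager

import fastapi


def build_expensive_info() -> dict:
    # Hypothetical stand-in for an expensive startup computation.
    return {"name": "Example Inference API", "pipelines": []}


@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):
    # Runs once at startup: compute the expensive value and stash it on app.state.
    app.state.service_info = build_expensive_info()
    yield
    # Runs once at shutdown; nothing to clean up in this sketch.


app = fastapi.FastAPI(lifespan=lifespan)


@app.get("/info")
async def info() -> dict:
    # Handlers read the cached value instead of rebuilding it on every request.
    return app.state.service_info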
241 changes: 240 additions & 1 deletion trapdata/api/datasets.py
@@ -1,12 +1,16 @@
import os
import typing
from io import BytesIO

import requests
import torch
import torch.utils.data
import torchvision
from PIL import Image

from trapdata.common.logs import logger

from .schemas import DetectionResponse, SourceImage
from .schemas import DetectionResponse, PipelineProcessingTask, SourceImage


class LocalizationImageDataset(torch.utils.data.Dataset):
@@ -87,3 +91,238 @@ def __getitem__(self, idx):

# return (ids_batch, image_batch)
return (source_image.id, detection_idx), image_data


class RESTDataset(torch.utils.data.IterableDataset):
"""
An IterableDataset that fetches tasks from a REST API endpoint and loads images.

The dataset continuously polls the API for tasks, loads the associated images,
and yields them as PyTorch tensors along with metadata.
"""

def __init__(
self,
base_url: str,
job_id: int,
batch_size: int = 1,
image_transforms: typing.Optional[torchvision.transforms.Compose] = None,
auth_token: typing.Optional[str] = None,
):
"""
Initialize the REST dataset.

Args:
base_url: Base URL for the API (e.g., "http://localhost:8000")
job_id: The job ID to fetch tasks for
batch_size: Number of tasks to request per batch
image_transforms: Optional transforms to apply to loaded images
auth_token: API authentication token. If not provided, reads from
ANTENNA_API_TOKEN environment variable
"""
super().__init__()
self.base_url = base_url.rstrip("/")
self.job_id = job_id
self.batch_size = batch_size
self.image_transforms = image_transforms or torchvision.transforms.ToTensor()
self.auth_token = auth_token or os.environ.get("ANTENNA_API_TOKEN")

def _fetch_tasks(self) -> list[PipelineProcessingTask]:
"""
Fetch a batch of tasks from the REST API.

Returns:
List of PipelineProcessingTask objects parsed from the API response
"""
url = f"{self.base_url}/api/v2/jobs/{self.job_id}/tasks"
params = {"batch": self.batch_size}

headers = {}
if self.auth_token:
headers["Authorization"] = f"Token {self.auth_token}"

try:
response = requests.get(
url,
params=params,
timeout=30,
headers=headers,
)
response.raise_for_status()
data = response.json()
tasks = [PipelineProcessingTask(**task) for task in data.get("tasks", [])]
return tasks
except requests.RequestException as e:
logger.error(f"Failed to fetch tasks from {url}: {e}")
return []

def _load_image(self, image_url: str) -> typing.Optional[torch.Tensor]:
"""
Load an image from a URL and convert it to a PyTorch tensor.

Args:
image_url: URL of the image to load

Returns:
Image as a PyTorch tensor, or None if loading failed
"""
try:
response = requests.get(image_url, timeout=30)
response.raise_for_status()
image = Image.open(BytesIO(response.content))

# Convert to RGB if necessary
if image.mode != "RGB":
image = image.convert("RGB")

# Apply transforms
image_tensor = self.image_transforms(image)
return image_tensor
except Exception as e:
logger.error(f"Failed to load image from {image_url}: {e}")
return None

def __iter__(self):
"""
Iterate over tasks from the REST API.

Yields:
Dictionary containing:
- image: PyTorch tensor of the loaded image, or None if loading failed
- reply_subject: Reply subject for the task
- image_id: Image ID
- image_url: Image URL
- error: Error description (present only when loading failed)
"""
try:
# Get worker info for debugging
worker_info = torch.utils.data.get_worker_info()
worker_id = worker_info.id if worker_info else 0
num_workers = worker_info.num_workers if worker_info else 1

logger.info(
f"Worker {worker_id}/{num_workers} starting iteration for job {self.job_id}"
)

while True:
tasks = self._fetch_tasks()
# _, t = log_time()
# _, t = t(f"Worker {worker_id}: Fetched {len(tasks)} tasks from API")

# If no tasks returned, dataset is finished
if not tasks:
logger.info(
f"Worker {worker_id}: No more tasks for job {self.job_id}, terminating"
)
break

for task in tasks:
errors = []
# Load the image
# _, t = log_time()
image_tensor = (
self._load_image(task.image_url) if task.image_url else None
)
# _, t = t(f"Loaded image from {image_url}")

if image_tensor is None:
errors.append("failed to load image")

if errors:
logger.warning(
f"Worker {worker_id}: Errors in task for image '{task.image_id}': {', '.join(errors)}"
)

# Yield the data row
row = {
"image": image_tensor,
"reply_subject": task.reply_subject,
"image_id": task.image_id,
"image_url": task.image_url,
}
if errors:
row["error"] = ("; ".join(errors) if errors else None,)
yield row

logger.info(f"Worker {worker_id}: Iterator finished")
except Exception as e:
logger.error(f"Worker {worker_id}: Exception in iterator: {e}")
raise


def rest_collate_fn(batch: list[dict]) -> dict:
"""
Custom collate function that separates failed and successful items.

Returns a dict with:
- image: Stacked tensor of the valid images
- reply_subject: List of reply subjects for valid images
- image_id: List of image IDs for valid images
- image_url: List of image URLs for valid images
- failed_items: List of dicts with metadata for failed items
"""
successful = []
failed = []

for item in batch:
if item["image"] is None or item.get("error"):
# Failed item
failed.append(
{
"reply_subject": item["reply_subject"],
"image_id": item["image_id"],
"image_url": item.get("image_url"),
"error": item.get("error", "Unknown error"),
}
)
else:
# Successful item
successful.append(item)

# Collate successful items
if successful:
result = {
"image": torch.stack([item["image"] for item in successful]),
"reply_subject": [item["reply_subject"] for item in successful],
"image_id": [item["image_id"] for item in successful],
"image_url": [item.get("image_url") for item in successful],
}
else:
# Empty batch - all failed
result = {
"reply_subject": [],
"image_id": [],
"image_url": [],
}

result["failed_items"] = failed

return result


def get_rest_dataloader(
job_id: int,
base_url: str = "http://localhost:8000",
batch_size: int = 4,
num_workers: int = 2,
auth_token: typing.Optional[str] = None,
) -> torch.utils.data.DataLoader:
"""
Args:
job_id: Job ID to fetch tasks for
base_url: Base URL for the REST API (default: http://localhost:8000)
batch_size: Number of tasks/images per batch (default: 4)
num_workers: Number of DataLoader workers (default: 2)
auth_token: Optional API token; falls back to the ANTENNA_API_TOKEN env var
"""
assert base_url is not None, "Base URL must be provided"
base_url = base_url.rstrip("/")

dataset = RESTDataset(
base_url=base_url, job_id=job_id, batch_size=batch_size, auth_token=auth_token
)

return torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
num_workers=num_workers,
collate_fn=rest_collate_fn,
)
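For reference, a minimal sketch of how a worker loop might consume this dataloader, based on the batch keys produced by rest_collate_fn. This is illustrative only: the job ID is hypothetical, it assumes an Antenna instance at localhost:8000 with ANTENNA_API_TOKEN set in the environment, and run_inference / post_results are hypothetical stand-ins for the worker's model call and result posting.

from trapdata.api.datasets import get_rest_dataloader

dataloader = get_rest_dataloader(
    job_id=123,  # hypothetical job ID
    base_url="http://localhost:8000",
    batch_size=4,
    num_workers=2,
)

for batch in dataloader:
    # Items that failed to load are reported separately by rest_collate_fn.
    for failed in batch["failed_items"]:
        print(f"Skipping {failed['image_id']}: {failed['error']}")

    if "image" not in batch:
        continue  # every item in this batch failed to load

    images = batch["image"]            # stacked tensor of valid images
    image_ids = batch["image_id"]      # parallel list of image IDs
    reply_subjects = batch["reply_subject"]
    # results = run_inference(images)           # hypothetical model call
    # post_results(reply_subjects, results)     # hypothetical reply back to Antenna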