Merged
Changes from 2 commits
3 changes: 0 additions & 3 deletions trapdata/antenna/benchmark.py
@@ -88,7 +88,6 @@ def run_benchmark(
dataloader = get_rest_dataloader(
job_id=job_id,
settings=settings,
processing_service_name=service_name,
)

# Initialize ResultPoster for sending acknowledgments
@@ -141,7 +140,6 @@ def run_benchmark(
auth_token=auth_token,
job_id=job_id,
results=ack_results,
processing_service_name=service_name,
)
total_acks_sent += len(ack_results)

@@ -164,7 +162,6 @@ def run_benchmark(
auth_token=auth_token,
job_id=job_id,
results=error_results,
processing_service_name=service_name,
)
total_acks_sent += len(error_results)

16 changes: 10 additions & 6 deletions trapdata/antenna/client.py
@@ -6,7 +6,9 @@

from trapdata.antenna.schemas import (
AntennaJobsListResponse,
AntennaResultPostResponse,
AntennaTaskResult,
AntennaTaskResults,
JobDispatchMode,
)
from trapdata.api.utils import get_http_session
@@ -77,7 +79,6 @@ def post_batch_results(
auth_token: str,
job_id: int,
results: list[AntennaTaskResult],
processing_service_name: str,
) -> bool:
"""
Post batch results back to the API.
@@ -87,20 +88,23 @@ def post_batch_results(
auth_token: API authentication token
job_id: Job ID
results: List of AntennaTaskResult objects
processing_service_name: Name of the processing service

Returns:
True if successful, False otherwise
"""
url = f"{base_url.rstrip('/')}/jobs/{job_id}/result/"
payload = [r.model_dump(mode="json") for r in results]
payload = AntennaTaskResults(results=results)

with get_http_session(auth_token) as session:
try:
params = {"processing_service_name": processing_service_name}
response = session.post(url, json=payload, params=params, timeout=60)
response = session.post(
url, json=payload.model_dump(mode="json"), timeout=60
)
response.raise_for_status()
logger.debug(f"Successfully posted {len(results)} results to {url}")
result = AntennaResultPostResponse.model_validate(response.json())
logger.debug(
f"Posted {len(results)} results to job {job_id}: {result.results_queued} queued"
)
return True
except requests.RequestException as e:
logger.error(f"Failed to post results to {url}: {e}")
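
To make the new wire contract concrete, here is a minimal sketch of calling the updated client. It uses only the schemas touched in this PR; the base URL, token, and job ID are placeholders, and a post to an unreachable host is simply logged as a failure by the client.

from trapdata.antenna.client import post_batch_results
from trapdata.antenna.schemas import AntennaTaskResults

# The request body is now a wrapped object rather than a bare JSON list,
# and the processing_service_name query parameter is gone.
print(AntennaTaskResults(results=[]).model_dump(mode="json"))  # {'results': []}

# Updated call signature (placeholder URL/token/job ID, empty result batch):
post_batch_results(
    base_url="https://antenna.example.org/api/v2",
    auth_token="example-token",
    job_id=42,
    results=[],
)
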
24 changes: 9 additions & 15 deletions trapdata/antenna/datasets.py
@@ -15,7 +15,7 @@
├──────────────────────────────────────────────────────────────────┤
│ DataLoader workers (num_workers subprocesses) │
│ Each subprocess runs its own RESTDataset.__iter__ loop: │
│ 1. GET /tasks → fetch batch of task metadata from Antenna │
│ 1. POST /tasks → fetch batch of task metadata from Antenna │
│ 2. Download images (threaded, see below) │
│ 3. Yield individual (image_tensor, metadata) rows │
│ The DataLoader collates rows into GPU-sized batches. │
@@ -76,6 +76,7 @@
from trapdata.antenna.schemas import (
AntennaPipelineProcessingTask,
AntennaTasksListResponse,
AntennaTasksRequest,
)
from trapdata.api.utils import get_http_session
from trapdata.common.logs import logger
@@ -97,8 +98,8 @@ class RESTDataset(torch.utils.data.IterableDataset):
independently fetches different tasks from the shared queue.

With DataLoader num_workers > 0 (I/O subprocesses, not AMI instances):
Subprocess 1: GET /tasks → receives [1,2,3,4], removed from queue
Subprocess 2: GET /tasks → receives [5,6,7,8], removed from queue
Subprocess 1: POST /tasks → receives [1,2,3,4], removed from queue
Subprocess 2: POST /tasks → receives [5,6,7,8], removed from queue
No duplicates, safe for parallel processing
"""

@@ -109,7 +110,6 @@ def __init__(
job_id: int,
batch_size: int = 1,
image_transforms: torchvision.transforms.Compose | None = None,
processing_service_name: str = "",
):
"""
Initialize the REST dataset.
@@ -120,15 +120,13 @@ def __init__(
job_id: The job ID to fetch tasks for
batch_size: Number of tasks to request per batch
image_transforms: Optional transforms to apply to loaded images
processing_service_name: Name of the processing service
"""
super().__init__()
self.base_url = base_url
self.auth_token = auth_token
self.job_id = job_id
self.batch_size = batch_size
self.image_transforms = image_transforms or torchvision.transforms.ToTensor()
self.processing_service_name = processing_service_name

# These are created lazily in _ensure_sessions() because they contain
# unpicklable objects (ThreadPoolExecutor has a SimpleQueue) and
@@ -170,15 +168,14 @@ def _fetch_tasks(self) -> list[AntennaPipelineProcessingTask]:
Raises:
requests.RequestException: If the request fails (network error, etc.)
"""
url = f"{self.base_url.rstrip('/')}/jobs/{self.job_id}/tasks"
params = {
"batch": self.batch_size,
"processing_service_name": self.processing_service_name,
}
url = f"{self.base_url.rstrip('/')}/jobs/{self.job_id}/tasks/"
request_body = AntennaTasksRequest(batch_size=self.batch_size)

self._ensure_sessions()
assert self._api_session is not None
response = self._api_session.get(url, params=params, timeout=30)
response = self._api_session.post(
url, json=request_body.model_dump(), timeout=30
)
response.raise_for_status()

# Parse and validate response with Pydantic
@@ -410,7 +407,6 @@ def _no_op_collate_fn(batch: list[dict]) -> dict:
def get_rest_dataloader(
job_id: int,
settings: "Settings",
processing_service_name: str,
) -> torch.utils.data.DataLoader:
"""Create a DataLoader that fetches tasks from Antenna API.

@@ -427,14 +423,12 @@ def get_rest_dataloader(
- antenna_api_base_url / antenna_api_auth_token
- antenna_api_batch_size (tasks per API call and GPU batch size)
- num_workers (DataLoader subprocesses)
- processing_service_name (name of this worker)
"""
dataset = RESTDataset(
base_url=settings.antenna_api_base_url,
auth_token=settings.antenna_api_auth_token,
job_id=job_id,
batch_size=settings.antenna_api_batch_size,
processing_service_name=processing_service_name,
)

return torch.utils.data.DataLoader(
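
A small sketch of the request/response shapes behind the GET-to-POST switch described in the module docstring above, using only the schemas added in this PR; the batch size of 4 is illustrative.

from trapdata.antenna.schemas import AntennaTasksListResponse, AntennaTasksRequest

# Body each DataLoader worker now POSTs to {base_url}/jobs/{job_id}/tasks/
print(AntennaTasksRequest(batch_size=4).model_dump())  # {'batch_size': 4}

# A drained queue comes back as an empty task list (as the test server below returns):
batch = AntennaTasksListResponse.model_validate({"tasks": []})
assert batch.tasks == []
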
10 changes: 1 addition & 9 deletions trapdata/antenna/result_posting.py
@@ -23,7 +23,6 @@
import time
from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
from dataclasses import dataclass
from typing import Optional

from trapdata.antenna.client import post_batch_results
from trapdata.common.logs import logger
@@ -86,7 +85,6 @@ def post_async(
auth_token: str,
job_id: int,
results: list,
processing_service_name: str,
) -> None:
"""Post results asynchronously with backpressure control.

@@ -98,7 +96,6 @@ def post_async(
auth_token: API authentication token
job_id: Job ID for the results
results: List of result objects to post
processing_service_name: Name of the processing service
"""
# Clean up completed futures and update metrics
self._cleanup_completed_futures()
@@ -140,7 +137,6 @@ def post_async(
auth_token,
job_id,
results,
processing_service_name,
start_time,
)
self.pending_futures.append(future)
@@ -156,7 +152,6 @@ def _post_with_timing(
auth_token: str,
job_id: int,
results: list,
processing_service_name: str,
start_time: float,
) -> bool:
"""Internal method that times the post operation and updates metrics.
@@ -166,16 +161,13 @@ def _post_with_timing(
auth_token: API authentication token
job_id: Job ID for the results
results: List of result objects to post
processing_service_name: Name of the processing service
start_time: Timestamp when the post was initiated

Returns:
True if successful, False otherwise
"""
try:
success = post_batch_results(
base_url, auth_token, job_id, results, processing_service_name
)
success = post_batch_results(base_url, auth_token, job_id, results)
elapsed_time = time.time() - start_time

with self._metrics_lock:
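
For reference, a sketch of the slimmed-down post_async call. The ResultPoster constructor is not shown in this diff, so constructing it with no arguments is an assumption, and the URL, token, and job ID are placeholders.

from trapdata.antenna.result_posting import ResultPoster

poster = ResultPoster()  # assumption: constructor arguments are not visible in this diff
poster.post_async(
    base_url="https://antenna.example.org/api/v2",  # placeholder
    auth_token="example-token",  # placeholder
    job_id=42,
    results=[],  # list of AntennaTaskResult objects
    # processing_service_name is no longer part of this call
)
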
25 changes: 24 additions & 1 deletion trapdata/antenna/schemas.py
@@ -34,8 +34,14 @@ class AntennaJobsListResponse(pydantic.BaseModel):
results: list[AntennaJobListItem]


class AntennaTasksRequest(pydantic.BaseModel):
"""Request body for POST /api/v2/jobs/{job_id}/tasks/."""

batch_size: int = pydantic.Field(gt=0)


class AntennaTasksListResponse(pydantic.BaseModel):
"""Response from Antenna API GET /api/v2/jobs/{job_id}/tasks."""
"""Response from Antenna API POST /api/v2/jobs/{job_id}/tasks/."""

tasks: list[AntennaPipelineProcessingTask]

@@ -60,6 +66,23 @@ class AntennaTaskResults(pydantic.BaseModel):
results: list[AntennaTaskResult] = pydantic.Field(default_factory=list)


class QueuedTaskAcknowledgment(pydantic.BaseModel):
"""Acknowledgment for a single result queued for background processing."""

reply_subject: str
status: str
task_id: str


class AntennaResultPostResponse(pydantic.BaseModel):
"""Response from POST /api/v2/jobs/{job_id}/result/."""

status: str
job_id: int
results_queued: int
tasks: list[QueuedTaskAcknowledgment] = pydantic.Field(default_factory=list)


class AsyncPipelineRegistrationRequest(pydantic.BaseModel):
"""
Request to register pipelines from an async processing service
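
A quick sketch of how the new schemas validate, using only fields defined above; the reply_subject and task_id values are made up for illustration.

import pydantic

from trapdata.antenna.schemas import AntennaResultPostResponse, AntennaTasksRequest

# batch_size is declared with gt=0, so a zero-sized batch is rejected.
try:
    AntennaTasksRequest(batch_size=0)
except pydantic.ValidationError:
    pass  # expected

# An acknowledgment with per-task entries (illustrative values).
ack = AntennaResultPostResponse.model_validate(
    {
        "status": "accepted",
        "job_id": 7,
        "results_queued": 2,
        "tasks": [
            {"reply_subject": "jobs.7.results", "status": "queued", "task_id": "a1"},
            {"reply_subject": "jobs.7.results", "status": "queued", "task_id": "a2"},
        ],
    }
)
assert ack.results_queued == len(ack.tasks)
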
32 changes: 18 additions & 14 deletions trapdata/antenna/tests/antenna_api_server.py
@@ -11,8 +11,11 @@
AntennaJobListItem,
AntennaJobsListResponse,
AntennaPipelineProcessingTask,
AntennaResultPostResponse,
AntennaTaskResult,
AntennaTaskResults,
AntennaTasksListResponse,
AntennaTasksRequest,
AsyncPipelineRegistrationRequest,
AsyncPipelineRegistrationResponse,
)
@@ -62,47 +65,48 @@ def get_jobs(
return AntennaJobsListResponse(results=results)


@app.get("/api/v2/jobs/{job_id}/tasks")
def get_tasks(job_id: int, batch: int):
@app.post("/api/v2/jobs/{job_id}/tasks")
def get_tasks(job_id: int, payload: AntennaTasksRequest):
"""Return batch of tasks (atomically remove from queue).

Args:
job_id: Job ID to fetch tasks for
batch: Number of tasks to return
payload: Request body with batch_size

Returns:
AntennaTasksListResponse with batch of tasks
"""
if job_id not in _jobs_queue:
return AntennaTasksListResponse(tasks=[])

# Get up to `batch` tasks and remove them from queue
tasks = _jobs_queue[job_id][:batch]
_jobs_queue[job_id] = _jobs_queue[job_id][batch:]
# Get up to `batch_size` tasks and remove them from queue
tasks = _jobs_queue[job_id][: payload.batch_size]
_jobs_queue[job_id] = _jobs_queue[job_id][payload.batch_size :]

return AntennaTasksListResponse(tasks=tasks)


@app.post("/api/v2/jobs/{job_id}/result/")
def post_results(job_id: int, payload: list[dict]):
def post_results(job_id: int, payload: AntennaTaskResults) -> AntennaResultPostResponse:
"""Store posted results for test validation.

Args:
job_id: Job ID to post results for
payload: List of AntennaTaskResult dicts
payload: Validated batch of task results

Returns:
Success status
AntennaResultPostResponse acknowledgment
"""
if job_id not in _posted_results:
_posted_results[job_id] = []

# Parse each result dict into AntennaTaskResult
for result_dict in payload:
task_result = AntennaTaskResult(**result_dict)
_posted_results[job_id].append(task_result)
_posted_results[job_id].extend(payload.results)

return {"status": "ok"}
return AntennaResultPostResponse(
status="accepted",
job_id=job_id,
results_queued=len(payload.results),
)


@app.get("/api/v2/projects/")
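
The updated endpoints can be exercised end to end with FastAPI's TestClient, assuming fastapi and httpx are available in the test environment; the job ID below is arbitrary.

from fastapi.testclient import TestClient

from trapdata.antenna.tests.antenna_api_server import app

client = TestClient(app)

# Unknown job: the tasks endpoint returns an empty batch rather than an error.
resp = client.post("/api/v2/jobs/123/tasks", json={"batch_size": 4})
assert resp.json() == {"tasks": []}

# Posting an empty result batch yields the structured acknowledgment.
resp = client.post("/api/v2/jobs/123/result/", json={"results": []})
assert resp.json()["results_queued"] == 0
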
1 change: 0 additions & 1 deletion trapdata/antenna/tests/test_memory_leak.py
@@ -95,7 +95,6 @@ def on_batch(batch_num: int, items: int):
"quebec_vermont_moths_2023",
999,
self._make_settings(),
processing_service_name="test-service",
on_batch_complete=on_batch,
)

5 changes: 0 additions & 5 deletions trapdata/antenna/worker.py
@@ -113,7 +113,6 @@ def _worker_loop(gpu_id: int, pipelines: list[str]):
pipeline=pipeline,
job_id=job_id,
settings=settings,
processing_service_name=full_service_name,
device=device,
)
any_jobs = any_jobs or any_work_done
@@ -403,7 +402,6 @@ def _process_job(
pipeline: str,
job_id: int,
settings: Settings,
processing_service_name: str,
device: torch.device | None = None,
on_batch_complete: Callable | None = None,
) -> bool:
@@ -413,7 +411,6 @@ def _process_job(
pipeline: Pipeline name to use for processing (e.g., moth_binary, panama_moths_2024)
job_id: Job ID to process
settings: Settings object with antenna_api_* configuration
processing_service_name: Name of the processing service
device: The device to use for processing. Auto-detected if None.
on_batch_complete: Optional callback invoked after each batch, with kwargs
batch_num (int) and items (int, cumulative items processed so far).
@@ -424,7 +421,6 @@ def _process_job(
loader = get_rest_dataloader(
job_id=job_id,
settings=settings,
processing_service_name=processing_service_name,
)
classifier = None
detector = None
@@ -508,7 +504,6 @@ def _process_job(
settings.antenna_api_auth_token,
job_id,
batch_results,
processing_service_name,
)
batch_total, t_total = t_total()
logger.info(