Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions trapdata/antenna/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@

from trapdata.antenna.schemas import (
AntennaJobsListResponse,
AntennaResultPostResponse,
AntennaTaskResult,
AntennaTaskResults,
JobDispatchMode,
)
from trapdata.api.utils import get_http_session
Expand Down Expand Up @@ -93,14 +95,18 @@ def post_batch_results(
True if successful, False otherwise
"""
url = f"{base_url.rstrip('/')}/jobs/{job_id}/result/"
payload = [r.model_dump(mode="json") for r in results]
payload = AntennaTaskResults(results=results)

with get_http_session(auth_token) as session:
try:
params = {"processing_service_name": processing_service_name}
response = session.post(url, json=payload, params=params, timeout=60)
response = session.post(
url, json=payload.model_dump(mode="json"), timeout=60
)
response.raise_for_status()
logger.debug(f"Successfully posted {len(results)} results to {url}")
result = AntennaResultPostResponse.model_validate(response.json())
logger.debug(
f"Posted {len(results)} results to job {job_id}: {result.results_queued} queued"
)
return True
except requests.RequestException as e:
logger.error(f"Failed to post results to {url}: {e}")
Expand Down
18 changes: 9 additions & 9 deletions trapdata/antenna/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
├──────────────────────────────────────────────────────────────────┤
│ DataLoader workers (num_workers subprocesses) │
│ Each subprocess runs its own RESTDataset.__iter__ loop: │
│ 1. GET /tasks → fetch batch of task metadata from Antenna │
│ 1. POST /tasks → fetch batch of task metadata from Antenna │
│ 2. Download images (threaded, see below) │
│ 3. Yield individual (image_tensor, metadata) rows │
│ The DataLoader collates rows into GPU-sized batches. │
Expand Down Expand Up @@ -76,6 +76,7 @@
from trapdata.antenna.schemas import (
AntennaPipelineProcessingTask,
AntennaTasksListResponse,
AntennaTasksRequest,
)
from trapdata.api.utils import get_http_session
from trapdata.common.logs import logger
Expand All @@ -97,8 +98,8 @@ class RESTDataset(torch.utils.data.IterableDataset):
independently fetches different tasks from the shared queue.

With DataLoader num_workers > 0 (I/O subprocesses, not AMI instances):
Subprocess 1: GET /tasks → receives [1,2,3,4], removed from queue
Subprocess 2: GET /tasks → receives [5,6,7,8], removed from queue
Subprocess 1: POST /tasks → receives [1,2,3,4], removed from queue
Subprocess 2: POST /tasks → receives [5,6,7,8], removed from queue
No duplicates, safe for parallel processing
"""

Expand Down Expand Up @@ -170,15 +171,14 @@ def _fetch_tasks(self) -> list[AntennaPipelineProcessingTask]:
Raises:
requests.RequestException: If the request fails (network error, etc.)
"""
url = f"{self.base_url.rstrip('/')}/jobs/{self.job_id}/tasks"
params = {
"batch": self.batch_size,
"processing_service_name": self.processing_service_name,
}
url = f"{self.base_url.rstrip('/')}/jobs/{self.job_id}/tasks/"
request_body = AntennaTasksRequest(batch_size=self.batch_size)

self._ensure_sessions()
assert self._api_session is not None
response = self._api_session.get(url, params=params, timeout=30)
response = self._api_session.post(
url, json=request_body.model_dump(), timeout=30
)
response.raise_for_status()

# Parse and validate response with Pydantic
Expand Down
25 changes: 24 additions & 1 deletion trapdata/antenna/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,14 @@ class AntennaJobsListResponse(pydantic.BaseModel):
results: list[AntennaJobListItem]


class AntennaTasksRequest(pydantic.BaseModel):
    """Request body for POST /api/v2/jobs/{job_id}/tasks/."""

    # Maximum number of queued tasks the server should return in one
    # response; the server removes the returned tasks from its queue.
    batch_size: int


class AntennaTasksListResponse(pydantic.BaseModel):
    """Response from Antenna API POST /api/v2/jobs/{job_id}/tasks/."""

    # Batch of tasks handed to this worker. The server dequeues them on
    # delivery, so an empty list signals the job's queue is drained.
    tasks: list[AntennaPipelineProcessingTask]

Expand All @@ -60,6 +66,23 @@ class AntennaTaskResults(pydantic.BaseModel):
results: list[AntennaTaskResult] = pydantic.Field(default_factory=list)


class QueuedTaskAcknowledgment(pydantic.BaseModel):
    """Acknowledgment for a single result queued for background processing."""

    # Subject on which a reply for this task is expected to be published
    # (presumably a message-bus subject — verify against the Antenna API).
    reply_subject: str
    # Per-task queueing status reported by the server.
    status: str
    # Identifier of the task this acknowledgment refers to.
    task_id: str


class AntennaResultPostResponse(pydantic.BaseModel):
    """Response from POST /api/v2/jobs/{job_id}/result/."""

    # Overall status of the post, e.g. "accepted" (see the test server).
    status: str
    # Job the results were posted to.
    job_id: int
    # Number of results the server queued for background processing.
    results_queued: int
    # Optional per-task acknowledgments; may be empty.
    tasks: list[QueuedTaskAcknowledgment] = pydantic.Field(default_factory=list)


class AsyncPipelineRegistrationRequest(pydantic.BaseModel):
"""
Request to register pipelines from an async processing service
Expand Down
32 changes: 18 additions & 14 deletions trapdata/antenna/tests/antenna_api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@
AntennaJobListItem,
AntennaJobsListResponse,
AntennaPipelineProcessingTask,
AntennaResultPostResponse,
AntennaTaskResult,
AntennaTaskResults,
AntennaTasksListResponse,
AntennaTasksRequest,
AsyncPipelineRegistrationRequest,
AsyncPipelineRegistrationResponse,
)
Expand Down Expand Up @@ -62,47 +65,48 @@ def get_jobs(
return AntennaJobsListResponse(results=results)


# Trailing slash added to match the client URL (".../tasks/") and the
# path documented on AntennaTasksRequest/AntennaTasksListResponse; stale
# GET-version lines ("batch: int", "[:batch]") removed — `batch` is not
# defined in the new signature.
@app.post("/api/v2/jobs/{job_id}/tasks/")
def get_tasks(job_id: int, payload: AntennaTasksRequest):
    """Return a batch of tasks, atomically removing them from the queue.

    Args:
        job_id: Job ID to fetch tasks for
        payload: Request body carrying the requested batch_size

    Returns:
        AntennaTasksListResponse with up to ``payload.batch_size`` tasks
    """
    # Unknown job: reply with an empty batch rather than an error so
    # pollers treat "no such job" and "drained queue" uniformly.
    if job_id not in _jobs_queue:
        return AntennaTasksListResponse(tasks=[])

    # Slice off up to `batch_size` tasks and drop them from the queue in
    # one step so parallel workers never receive duplicates.
    tasks = _jobs_queue[job_id][: payload.batch_size]
    _jobs_queue[job_id] = _jobs_queue[job_id][payload.batch_size :]

    return AntennaTasksListResponse(tasks=tasks)


@app.post("/api/v2/jobs/{job_id}/result/")
def post_results(
    job_id: int, payload: AntennaTaskResults
) -> AntennaResultPostResponse:
    """Store posted results for test validation.

    Args:
        job_id: Job ID to post results for
        payload: Validated batch of task results (FastAPI parses the body
            into AntennaTaskResults, replacing the old list[dict] handling)

    Returns:
        AntennaResultPostResponse acknowledgment
    """
    # setdefault keeps first-write semantics while collapsing the
    # membership check + assignment + extend into one idiomatic call.
    _posted_results.setdefault(job_id, []).extend(payload.results)

    return AntennaResultPostResponse(
        status="accepted",
        job_id=job_id,
        results_queued=len(payload.results),
    )


@app.get("/api/v2/projects/")
Expand Down
Loading