diff --git a/trapdata/antenna/benchmark.py b/trapdata/antenna/benchmark.py
index 483391e..4065fec 100644
--- a/trapdata/antenna/benchmark.py
+++ b/trapdata/antenna/benchmark.py
@@ -53,7 +53,6 @@ def run_benchmark(
     num_workers: int,
     batch_size: int,
     gpu_batch_size: int,
-    service_name: str,
     send_acks: bool = True,
 ) -> None:
     """Run the benchmark with the specified parameters.
@@ -65,7 +64,6 @@ def run_benchmark(
         num_workers: Number of DataLoader workers
         batch_size: Batch size for API requests
         gpu_batch_size: GPU batch size for DataLoader
-        service_name: Processing service name
     """
     # Create settings object
     settings = Settings()
@@ -81,14 +79,12 @@ def run_benchmark(
     print(f" API batch size: {batch_size}")
     print(f" GPU batch size: {gpu_batch_size}")
     print(f" Num workers: {num_workers}")
-    print(f" Service name: {service_name}")
     print()
 
     # Create dataloader
     dataloader = get_rest_dataloader(
         job_id=job_id,
         settings=settings,
-        processing_service_name=service_name,
     )
 
     # Initialize ResultPoster for sending acknowledgments
@@ -141,7 +137,6 @@ def run_benchmark(
                 auth_token=auth_token,
                 job_id=job_id,
                 results=ack_results,
-                processing_service_name=service_name,
             )
             total_acks_sent += len(ack_results)
 
@@ -164,7 +159,6 @@ def run_benchmark(
                 auth_token=auth_token,
                 job_id=job_id,
                 results=error_results,
-                processing_service_name=service_name,
             )
             total_acks_sent += len(error_results)
 
@@ -275,12 +269,6 @@ def main() -> int:
     parser.add_argument(
         "--gpu-batch-size", type=int, default=16, help="GPU batch size for DataLoader"
     )
-    parser.add_argument(
-        "--service-name",
-        type=str,
-        default="Performance Test",
-        help="Processing service name",
-    )
     parser.add_argument(
         "--skip-acks",
         action="store_false",
@@ -303,7 +291,6 @@ def main() -> int:
         num_workers=args.num_workers,
         batch_size=args.batch_size,
         gpu_batch_size=args.gpu_batch_size,
-        service_name=args.service_name,
         send_acks=args.skip_acks,
     )
     return 0
diff --git a/trapdata/antenna/client.py b/trapdata/antenna/client.py
index 0bdbfc9..5e8cde6 100644
--- a/trapdata/antenna/client.py
+++ b/trapdata/antenna/client.py
@@ -6,7 +6,9 @@
 
 from trapdata.antenna.schemas import (
     AntennaJobsListResponse,
+    AntennaResultPostResponse,
     AntennaTaskResult,
+    AntennaTaskResults,
     JobDispatchMode,
 )
 from trapdata.api.utils import get_http_session
@@ -30,17 +32,15 @@ def get_jobs(
     base_url: str,
     auth_token: str,
     pipeline_slugs: list[str],
-    processing_service_name: str,
 ) -> list[tuple[int, str]]:
     """Fetch job ids from the API for the given pipelines in a single request.
 
-    Calls: GET {base_url}/jobs?pipeline__slug__in=&ids_only=1&processing_service_name=
+    Calls: GET {base_url}/jobs?pipeline__slug__in=&ids_only=1
 
     Args:
         base_url: Antenna API base URL (e.g., "http://localhost:8000/api/v2")
         auth_token: API authentication token
         pipeline_slugs: List of pipeline slugs to filter jobs
-        processing_service_name: Name of the processing service
 
     Returns:
         List of (job_id, pipeline_slug) tuples (possibly empty) on success or error.
@@ -54,7 +54,6 @@ def get_jobs(
         "pipeline__slug__in": ",".join(pipeline_slugs),
         "ids_only": 1,
         "incomplete_only": 1,
-        "processing_service_name": processing_service_name,
         "dispatch_mode": JobDispatchMode.ASYNC_API,  # Only fetch async_api jobs
     }
 
@@ -77,7 +76,6 @@ def post_batch_results(
     auth_token: str,
     job_id: int,
     results: list[AntennaTaskResult],
-    processing_service_name: str,
 ) -> bool:
     """
     Post batch results back to the API.
@@ -87,20 +85,23 @@ def post_batch_results(
         auth_token: API authentication token
         job_id: Job ID
         results: List of AntennaTaskResult objects
-        processing_service_name: Name of the processing service
 
     Returns:
         True if successful, False otherwise
     """
     url = f"{base_url.rstrip('/')}/jobs/{job_id}/result/"
-    payload = [r.model_dump(mode="json") for r in results]
+    payload = AntennaTaskResults(results=results)
 
     with get_http_session(auth_token) as session:
         try:
-            params = {"processing_service_name": processing_service_name}
-            response = session.post(url, json=payload, params=params, timeout=60)
+            response = session.post(
+                url, json=payload.model_dump(mode="json"), timeout=60
+            )
             response.raise_for_status()
-            logger.debug(f"Successfully posted {len(results)} results to {url}")
+            result = AntennaResultPostResponse.model_validate(response.json())
+            logger.debug(
+                f"Posted {len(results)} results to job {job_id}: {result.results_queued} queued"
+            )
             return True
         except requests.RequestException as e:
             logger.error(f"Failed to post results to {url}: {e}")
diff --git a/trapdata/antenna/datasets.py b/trapdata/antenna/datasets.py
index 7602e1f..7ecc7bd 100644
--- a/trapdata/antenna/datasets.py
+++ b/trapdata/antenna/datasets.py
@@ -15,7 +15,7 @@
 ├────────────────────────────────────────────────────────────────┤
 │ DataLoader workers (num_workers subprocesses)                  │
 │ Each subprocess runs its own RESTDataset.__iter__ loop:        │
-│ 1. GET /tasks → fetch batch of task metadata from Antenna      │
+│ 1. POST /tasks → fetch batch of task metadata from Antenna     │
 │ 2. Download images (threaded, see below)                       │
 │ 3. Yield individual (image_tensor, metadata) rows              │
 │ The DataLoader collates rows into GPU-sized batches.           │
@@ -76,6 +76,7 @@
 from trapdata.antenna.schemas import (
     AntennaPipelineProcessingTask,
     AntennaTasksListResponse,
+    AntennaTasksRequest,
 )
 from trapdata.api.utils import get_http_session
 from trapdata.common.logs import logger
@@ -97,8 +98,8 @@ class RESTDataset(torch.utils.data.IterableDataset):
     independently fetches different tasks from the shared queue.
 
     With DataLoader num_workers > 0 (I/O subprocesses, not AMI instances):
-    Subprocess 1: GET /tasks → receives [1,2,3,4], removed from queue
-    Subprocess 2: GET /tasks → receives [5,6,7,8], removed from queue
+    Subprocess 1: POST /tasks → receives [1,2,3,4], removed from queue
+    Subprocess 2: POST /tasks → receives [5,6,7,8], removed from queue
     No duplicates, safe for parallel processing
     """
 
@@ -109,7 +110,6 @@ def __init__(
         job_id: int,
         batch_size: int = 1,
         image_transforms: torchvision.transforms.Compose | None = None,
-        processing_service_name: str = "",
     ):
         """
         Initialize the REST dataset.
@@ -120,7 +120,6 @@ def __init__(
            job_id: The job ID to fetch tasks for
            batch_size: Number of tasks to request per batch
            image_transforms: Optional transforms to apply to loaded images
-           processing_service_name: Name of the processing service
        """
        super().__init__()
        self.base_url = base_url
@@ -128,7 +127,6 @@ def __init__(
        self.job_id = job_id
        self.batch_size = batch_size
        self.image_transforms = image_transforms or torchvision.transforms.ToTensor()
-       self.processing_service_name = processing_service_name
 
        # These are created lazily in _ensure_sessions() because they contain
        # unpicklable objects (ThreadPoolExecutor has a SimpleQueue) and
@@ -170,15 +168,14 @@ def _fetch_tasks(self) -> list[AntennaPipelineProcessingTask]:
        Raises:
            requests.RequestException: If the request fails (network error, etc.)
""" - url = f"{self.base_url.rstrip('/')}/jobs/{self.job_id}/tasks" - params = { - "batch": self.batch_size, - "processing_service_name": self.processing_service_name, - } + url = f"{self.base_url.rstrip('/')}/jobs/{self.job_id}/tasks/" + request_body = AntennaTasksRequest(batch_size=self.batch_size) self._ensure_sessions() assert self._api_session is not None - response = self._api_session.get(url, params=params, timeout=30) + response = self._api_session.post( + url, json=request_body.model_dump(), timeout=30 + ) response.raise_for_status() # Parse and validate response with Pydantic @@ -410,7 +407,6 @@ def _no_op_collate_fn(batch: list[dict]) -> dict: def get_rest_dataloader( job_id: int, settings: "Settings", - processing_service_name: str, ) -> torch.utils.data.DataLoader: """Create a DataLoader that fetches tasks from Antenna API. @@ -427,14 +423,12 @@ def get_rest_dataloader( - antenna_api_base_url / antenna_api_auth_token - antenna_api_batch_size (tasks per API call and GPU batch size) - num_workers (DataLoader subprocesses) - - processing_service_name (name of this worker) """ dataset = RESTDataset( base_url=settings.antenna_api_base_url, auth_token=settings.antenna_api_auth_token, job_id=job_id, batch_size=settings.antenna_api_batch_size, - processing_service_name=processing_service_name, ) return torch.utils.data.DataLoader( diff --git a/trapdata/antenna/result_posting.py b/trapdata/antenna/result_posting.py index 16207ff..cd76737 100644 --- a/trapdata/antenna/result_posting.py +++ b/trapdata/antenna/result_posting.py @@ -14,7 +14,7 @@ Usage: poster = ResultPoster(max_pending=5) - poster.post_async(base_url, auth_token, job_id, results, service_name) + poster.post_async(base_url, auth_token, job_id, results) metrics = poster.get_metrics() poster.shutdown() """ @@ -23,7 +23,6 @@ import time from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait from dataclasses import dataclass -from typing import Optional from trapdata.antenna.client import post_batch_results from trapdata.common.logs import logger @@ -61,7 +60,7 @@ class ResultPoster: Example: poster = ResultPoster(max_pending=10) - poster.post_async(base_url, auth_token, job_id, results, service_name) + poster.post_async(base_url, auth_token, job_id, results) metrics = poster.get_metrics() poster.shutdown() """ @@ -86,7 +85,6 @@ def post_async( auth_token: str, job_id: int, results: list, - processing_service_name: str, ) -> None: """Post results asynchronously with backpressure control. @@ -98,7 +96,6 @@ def post_async( auth_token: API authentication token job_id: Job ID for the results results: List of result objects to post - processing_service_name: Name of the processing service """ # Clean up completed futures and update metrics self._cleanup_completed_futures() @@ -140,7 +137,6 @@ def post_async( auth_token, job_id, results, - processing_service_name, start_time, ) self.pending_futures.append(future) @@ -156,7 +152,6 @@ def _post_with_timing( auth_token: str, job_id: int, results: list, - processing_service_name: str, start_time: float, ) -> bool: """Internal method that times the post operation and updates metrics. 
@@ -166,16 +161,13 @@ def _post_with_timing(
            auth_token: API authentication token
            job_id: Job ID for the results
            results: List of result objects to post
-           processing_service_name: Name of the processing service
            start_time: Timestamp when the post was initiated
 
        Returns:
            True if successful, False otherwise
        """
        try:
-           success = post_batch_results(
-               base_url, auth_token, job_id, results, processing_service_name
-           )
+           success = post_batch_results(base_url, auth_token, job_id, results)
            elapsed_time = time.time() - start_time
 
            with self._metrics_lock:
diff --git a/trapdata/antenna/schemas.py b/trapdata/antenna/schemas.py
index 32eba89..be64eef 100644
--- a/trapdata/antenna/schemas.py
+++ b/trapdata/antenna/schemas.py
@@ -34,8 +34,14 @@ class AntennaJobsListResponse(pydantic.BaseModel):
     results: list[AntennaJobListItem]
 
 
+class AntennaTasksRequest(pydantic.BaseModel):
+    """Request body for POST /api/v2/jobs/{job_id}/tasks/."""
+
+    batch_size: int = pydantic.Field(gt=0)
+
+
 class AntennaTasksListResponse(pydantic.BaseModel):
-    """Response from Antenna API GET /api/v2/jobs/{job_id}/tasks."""
+    """Response from Antenna API POST /api/v2/jobs/{job_id}/tasks/."""
 
     tasks: list[AntennaPipelineProcessingTask]
 
@@ -60,6 +66,23 @@ class AntennaTaskResults(pydantic.BaseModel):
     results: list[AntennaTaskResult] = pydantic.Field(default_factory=list)
 
 
+class QueuedTaskAcknowledgment(pydantic.BaseModel):
+    """Acknowledgment for a single result queued for background processing."""
+
+    reply_subject: str
+    status: str
+    task_id: str
+
+
+class AntennaResultPostResponse(pydantic.BaseModel):
+    """Response from POST /api/v2/jobs/{job_id}/result/."""
+
+    status: str
+    job_id: int
+    results_queued: int
+    tasks: list[QueuedTaskAcknowledgment] = pydantic.Field(default_factory=list)
+
+
 class AsyncPipelineRegistrationRequest(pydantic.BaseModel):
     """
     Request to register pipelines from an async processing service
diff --git a/trapdata/antenna/tests/antenna_api_server.py b/trapdata/antenna/tests/antenna_api_server.py
index 3731835..9a5f93b 100644
--- a/trapdata/antenna/tests/antenna_api_server.py
+++ b/trapdata/antenna/tests/antenna_api_server.py
@@ -11,20 +11,22 @@
     AntennaJobListItem,
     AntennaJobsListResponse,
     AntennaPipelineProcessingTask,
+    AntennaResultPostResponse,
     AntennaTaskResult,
+    AntennaTaskResults,
     AntennaTasksListResponse,
+    AntennaTasksRequest,
     AsyncPipelineRegistrationRequest,
     AsyncPipelineRegistrationResponse,
 )
 
-app = FastAPI()
+app = FastAPI(redirect_slashes=False)
 
 # State management for tests
 _jobs_queue: dict[int, list[AntennaPipelineProcessingTask]] = {}
 _posted_results: dict[int, list[AntennaTaskResult]] = {}
 _projects: list[dict] = []
 _registered_pipelines: dict[int, list[str]] = {}  # project_id -> pipeline slugs
-_last_get_jobs_service_name: str = ""
 
 
 @app.get("/api/v2/jobs")
@@ -32,7 +34,6 @@ def get_jobs(
     pipeline__slug__in: str = "",
     ids_only: int = 1,
     incomplete_only: int = 1,
-    processing_service_name: str = "",
 ):
     """Return available job IDs.
 
@@ -40,14 +41,10 @@ def get_jobs(
        pipeline__slug__in: Comma-separated pipeline slugs filter
        ids_only: If 1, return only job IDs
        incomplete_only: If 1, return only incomplete jobs
-       processing_service_name: Name of the processing service making the request
    Returns:
        AntennaJobsListResponse with list of job IDs
    """
-   global _last_get_jobs_service_name
-   _last_get_jobs_service_name = processing_service_name
-
    # Determine pipeline slug for response (use first slug from filter)
    slugs = (
        [s for s in pipeline__slug__in.split(",") if s] if pipeline__slug__in else []
    )
@@ -62,13 +59,13 @@
     return AntennaJobsListResponse(results=results)
 
 
-@app.get("/api/v2/jobs/{job_id}/tasks")
-def get_tasks(job_id: int, batch: int):
+@app.post("/api/v2/jobs/{job_id}/tasks/")
+def get_tasks(job_id: int, payload: AntennaTasksRequest):
     """Return batch of tasks (atomically remove from queue).
 
     Args:
         job_id: Job ID to fetch tasks for
-        batch: Number of tasks to return
+        payload: Request body with batch_size
 
     Returns:
         AntennaTasksListResponse with batch of tasks
@@ -76,33 +73,34 @@ def get_tasks(job_id: int, batch: int):
     if job_id not in _jobs_queue:
         return AntennaTasksListResponse(tasks=[])
 
-    # Get up to `batch` tasks and remove them from queue
-    tasks = _jobs_queue[job_id][:batch]
-    _jobs_queue[job_id] = _jobs_queue[job_id][batch:]
+    # Get up to `batch_size` tasks and remove them from queue
+    tasks = _jobs_queue[job_id][: payload.batch_size]
+    _jobs_queue[job_id] = _jobs_queue[job_id][payload.batch_size :]
 
     return AntennaTasksListResponse(tasks=tasks)
 
 
 @app.post("/api/v2/jobs/{job_id}/result/")
-def post_results(job_id: int, payload: list[dict]):
+def post_results(job_id: int, payload: AntennaTaskResults) -> AntennaResultPostResponse:
     """Store posted results for test validation.
 
     Args:
         job_id: Job ID to post results for
-        payload: List of AntennaTaskResult dicts
+        payload: Validated batch of task results
 
     Returns:
-        Success status
+        AntennaResultPostResponse acknowledgment
     """
     if job_id not in _posted_results:
         _posted_results[job_id] = []
 
-    # Parse each result dict into AntennaTaskResult
-    for result_dict in payload:
-        task_result = AntennaTaskResult(**result_dict)
-        _posted_results[job_id].append(task_result)
+    _posted_results[job_id].extend(payload.results)
 
-    return {"status": "ok"}
+    return AntennaResultPostResponse(
+        status="accepted",
+        job_id=job_id,
+        results_queued=len(payload.results),
+    )
 
 
 @app.get("/api/v2/projects/")
@@ -198,21 +196,9 @@ def get_registered_pipelines(project_id: int) -> list[str]:
     return _registered_pipelines.get(project_id, [])
 
 
-def get_last_get_jobs_service_name() -> str:
-    """Return the processing_service_name received by the last get_jobs call.
-
-    Returns:
-        The processing_service_name value from the most recent GET /jobs request,
-        or an empty string if no request has been made since the last reset().
- """ - return _last_get_jobs_service_name - - def reset(): """Clear all state between tests.""" - global _last_get_jobs_service_name _jobs_queue.clear() _posted_results.clear() _projects.clear() _registered_pipelines.clear() - _last_get_jobs_service_name = "" diff --git a/trapdata/antenna/tests/test_memory_leak.py b/trapdata/antenna/tests/test_memory_leak.py index 84c6930..a09c14c 100644 --- a/trapdata/antenna/tests/test_memory_leak.py +++ b/trapdata/antenna/tests/test_memory_leak.py @@ -39,7 +39,7 @@ def setUpClass(cls): cls.test_images_dir = pathlib.Path(TEST_IMAGES_BASE_PATH) cls.file_server = StaticFileTestServer(cls.test_images_dir) cls.file_server.start() - cls.antenna_client = TestClient(antenna_app) + cls.antenna_client = TestClient(antenna_app, follow_redirects=False) @classmethod def tearDownClass(cls): @@ -95,7 +95,6 @@ def on_batch(batch_num: int, items: int): "quebec_vermont_moths_2023", 999, self._make_settings(), - processing_service_name="test-service", on_batch_complete=on_batch, ) diff --git a/trapdata/antenna/tests/test_worker.py b/trapdata/antenna/tests/test_worker.py index 7fda2d2..f6b9079 100644 --- a/trapdata/antenna/tests/test_worker.py +++ b/trapdata/antenna/tests/test_worker.py @@ -151,7 +151,7 @@ def setUpClass(cls): cls.file_server.start() # Start server and keep it running for all tests # Setup mock Antenna API - cls.antenna_client = TestClient(antenna_app) + cls.antenna_client = TestClient(antenna_app, follow_redirects=False) @classmethod def tearDownClass(cls): @@ -208,7 +208,7 @@ class TestGetJobsIntegration(TestCase): @classmethod def setUpClass(cls): - cls.antenna_client = TestClient(antenna_app) + cls.antenna_client = TestClient(antenna_app, follow_redirects=False) def setUp(self): antenna_api_server.reset() @@ -225,11 +225,9 @@ def test_returns_job_ids(self): "http://testserver/api/v2", "test-token", ["moths_2024"], - "Test Worker", ) assert [job_id for job_id, _ in result] == [10, 20, 30] - assert antenna_api_server.get_last_get_jobs_service_name() == "Test Worker" # --------------------------------------------------------------------------- @@ -245,7 +243,7 @@ def setUpClass(cls): cls.test_images_dir = pathlib.Path(TEST_IMAGES_BASE_PATH) cls.file_server = StaticFileTestServer(cls.test_images_dir) cls.file_server.start() # Start server and keep it running for all tests - cls.antenna_client = TestClient(antenna_app) + cls.antenna_client = TestClient(antenna_app, follow_redirects=False) @classmethod def tearDownClass(cls): @@ -273,7 +271,6 @@ def test_empty_queue(self): "quebec_vermont_moths_2023", 100, self._make_settings(), - "Test Service", device=torch.device("cpu"), ) @@ -302,7 +299,6 @@ def test_processes_batch_with_real_inference(self): "quebec_vermont_moths_2023", 101, self._make_settings(), - "Test Service", device=torch.device("cpu"), ) @@ -342,7 +338,6 @@ def test_handles_failed_items(self): "quebec_vermont_moths_2023", 102, self._make_settings(), - "Test Service", device=torch.device("cpu"), ) @@ -379,7 +374,6 @@ def test_mixed_batch_success_and_failures(self): "quebec_vermont_moths_2023", 103, self._make_settings(), - "Test Service", device=torch.device("cpu"), ) @@ -411,7 +405,7 @@ def setUpClass(cls): cls.test_images_dir = pathlib.Path(TEST_IMAGES_BASE_PATH) cls.file_server = StaticFileTestServer(cls.test_images_dir) cls.file_server.start() # Start server and keep it running for all tests - cls.antenna_client = TestClient(antenna_app) + cls.antenna_client = TestClient(antenna_app, follow_redirects=False) @classmethod def 
@@ -473,18 +467,15 @@ def test_full_workflow_with_real_inference(self):
            "http://testserver/api/v2",
            "test-token",
            [pipeline_slug],
-           "Test Worker",
        )
        job_ids = [job_id for job_id, _ in jobs]
        assert 200 in job_ids
-       assert antenna_api_server.get_last_get_jobs_service_name() == "Test Worker"
 
        # Step 3: Process job
        result = _process_job(
            pipeline_slug,
            200,
            self._make_settings(),
-           "Test Worker",
            device=torch.device("cpu"),
        )
        assert result is True
@@ -536,7 +527,6 @@ def test_multiple_batches_processed(self):
            "quebec_vermont_moths_2023",
            201,
            self._make_settings(),
-           "Test Service",
            device=torch.device("cpu"),
        )
 
diff --git a/trapdata/antenna/worker.py b/trapdata/antenna/worker.py
index ea5f3a3..2b7e1db 100644
--- a/trapdata/antenna/worker.py
+++ b/trapdata/antenna/worker.py
@@ -102,7 +102,6 @@ def _worker_loop(gpu_id: int, pipelines: list[str]):
            base_url=settings.antenna_api_base_url,
            auth_token=settings.antenna_api_auth_token,
            pipeline_slugs=pipelines,
-           processing_service_name=full_service_name,
        )
        for job_id, pipeline in jobs:
            logger.info(
@@ -113,7 +112,6 @@ def _worker_loop(gpu_id: int, pipelines: list[str]):
                pipeline=pipeline,
                job_id=job_id,
                settings=settings,
-               processing_service_name=full_service_name,
                device=device,
            )
            any_jobs = any_jobs or any_work_done
@@ -403,7 +401,6 @@ def _process_job(
     pipeline: str,
     job_id: int,
     settings: Settings,
-    processing_service_name: str,
     device: torch.device | None = None,
     on_batch_complete: Callable | None = None,
 ) -> bool:
@@ -413,7 +410,6 @@ def _process_job(
        pipeline: Pipeline name to use for processing (e.g., moth_binary, panama_moths_2024)
        job_id: Job ID to process
        settings: Settings object with antenna_api_* configuration
-       processing_service_name: Name of the processing service
        device: The device to use for processing. Auto-detected if None.
        on_batch_complete: Optional callback invoked after each batch, with kwargs
            batch_num (int) and items (int, cumulative items processed so far).
@@ -424,7 +420,6 @@ def _process_job(
    loader = get_rest_dataloader(
        job_id=job_id,
        settings=settings,
-       processing_service_name=processing_service_name,
    )
    classifier = None
    detector = None
@@ -508,7 +503,6 @@ def _process_job(
                settings.antenna_api_auth_token,
                job_id,
                batch_results,
-               processing_service_name,
            )
            batch_total, t_total = t_total()
            logger.info(
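
Usage sketch (not part of the diff): the Python snippet below shows roughly how a worker exercises the updated client API after this change. It is illustrative only; the Settings import path, the pipeline slug, and the empty results list are assumptions, and the inference step that would produce AntennaTaskResult objects is omitted.

# Illustrative walkthrough of the updated flow; names not present in the diff
# (the Settings import path, the pipeline slug, the placeholder results list)
# are assumptions, not part of the change itself.
from trapdata.antenna.client import get_jobs, post_batch_results
from trapdata.antenna.datasets import get_rest_dataloader
from trapdata.settings import Settings  # assumed import path for the Settings object

settings = Settings()  # expected to carry antenna_api_base_url, antenna_api_auth_token, etc.

# 1. Discover incomplete async-API jobs; no processing_service_name is sent anymore.
jobs = get_jobs(
    base_url=settings.antenna_api_base_url,
    auth_token=settings.antenna_api_auth_token,
    pipeline_slugs=["moth_binary"],  # placeholder slug
)

for job_id, pipeline_slug in jobs:
    # 2. The DataLoader workers now POST an AntennaTasksRequest body to
    #    /jobs/{job_id}/tasks/ instead of sending query params to a GET endpoint.
    loader = get_rest_dataloader(job_id=job_id, settings=settings)

    for batch in loader:
        results = []  # AntennaTaskResult objects produced by inference (omitted)
        # 3. Results are wrapped in an AntennaTaskResults envelope and the
        #    AntennaResultPostResponse acknowledgment is parsed inside
        #    post_batch_results; callers still just receive a bool.
        post_batch_results(
            settings.antenna_api_base_url,
            settings.antenna_api_auth_token,
            job_id,
            results,
        )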