From 456b0259009dbd02413b2662d6ce20670d17be7d Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Mon, 9 Feb 2026 18:27:52 -0800 Subject: [PATCH 01/17] fix: return images as list in rest_collate_fn to support variable sizes torch.stack requires all tensors to be the same size, which crashes when a batch contains images of different resolutions (e.g. 3420x6080 and 2160x4096). FasterRCNN natively accepts a list of variable-sized images, and predict_batch already handles this code path. Co-Authored-By: Claude Opus 4.6 --- trapdata/antenna/datasets.py | 2 +- trapdata/antenna/tests/test_worker.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/trapdata/antenna/datasets.py b/trapdata/antenna/datasets.py index faf56b8..22801c9 100644 --- a/trapdata/antenna/datasets.py +++ b/trapdata/antenna/datasets.py @@ -231,7 +231,7 @@ def rest_collate_fn(batch: list[dict]) -> dict: # Collate successful items if successful: result = { - "images": torch.stack([item["image"] for item in successful]), + "images": [item["image"] for item in successful], "reply_subjects": [item["reply_subject"] for item in successful], "image_ids": [item["image_id"] for item in successful], "image_urls": [item.get("image_url") for item in successful], diff --git a/trapdata/antenna/tests/test_worker.py b/trapdata/antenna/tests/test_worker.py index 4a83958..0a3b8dd 100644 --- a/trapdata/antenna/tests/test_worker.py +++ b/trapdata/antenna/tests/test_worker.py @@ -55,7 +55,8 @@ def test_all_successful(self): result = rest_collate_fn(batch) assert "images" in result - assert result["images"].shape == (2, 3, 64, 64) + assert len(result["images"]) == 2 + assert result["images"][0].shape == (3, 64, 64) assert result["image_ids"] == ["img1", "img2"] assert result["reply_subjects"] == ["subj1", "subj2"] assert result["failed_items"] == [] @@ -104,7 +105,8 @@ def test_mixed(self): ] result = rest_collate_fn(batch) - assert result["images"].shape == (1, 3, 64, 64) + assert len(result["images"]) == 1 + assert result["images"][0].shape == (3, 64, 64) assert result["image_ids"] == ["img1"] assert len(result["failed_items"]) == 1 assert result["failed_items"][0]["image_id"] == "img2" From 54d9944d1a04806fce9808717005d6daad4399b7 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Mon, 9 Feb 2026 18:28:07 -0800 Subject: [PATCH 02/17] fix: handle batch processing errors per-batch instead of crashing job Wrap the batch processing loop body in try/except so a single failed batch doesn't kill the entire job. On failure, error results are posted back to Antenna for each image in the batch so tasks don't get stuck in the queue. Also downgrade post_batch_results failure from a raised exception to a logged error to avoid losing progress on subsequent batches. 
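In sketch form (simplified: process_batch and the exact post_batch_results
signature are illustrative stand-ins for the real names in worker.py):

    for i, batch in enumerate(dataloader):
        try:
            batch_results = process_batch(batch)  # detection + classification
        except Exception as e:
            logger.error(f"Batch {i + 1} failed during processing: {e}", exc_info=True)
            # One error result per image so the tasks leave the queue
            batch_results = [
                AntennaTaskResult(
                    reply_subject=subject,
                    result=AntennaTaskResultError(
                        error=f"Batch processing error: {e}",
                        image_id=image_id,
                    ),
                )
                for subject, image_id in zip(
                    batch["reply_subjects"], batch["image_ids"]
                )
            ]
        success = post_batch_results(batch_results)
        if not success:
            # Logged instead of raised so later batches still run
            logger.error("Failed to post batch results")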
Co-Authored-By: Claude Opus 4.6 --- trapdata/antenna/worker.py | 191 ++++++++++++++++++++----------------- 1 file changed, 104 insertions(+), 87 deletions(-) diff --git a/trapdata/antenna/worker.py b/trapdata/antenna/worker.py index 2fbf3b5..f31e74e 100644 --- a/trapdata/antenna/worker.py +++ b/trapdata/antenna/worker.py @@ -121,99 +121,118 @@ def _process_job( reply_subjects = batch.get("reply_subjects", [None] * len(images)) image_urls = batch.get("image_urls", [None] * len(images)) - # Validate all arrays have same length before zipping - if len(image_ids) != len(images): - raise ValueError( - f"Length mismatch: image_ids ({len(image_ids)}) != images ({len(images)})" - ) - if len(image_ids) != len(reply_subjects) or len(image_ids) != len(image_urls): - raise ValueError( - f"Length mismatch: image_ids ({len(image_ids)}), " - f"reply_subjects ({len(reply_subjects)}), image_urls ({len(image_urls)})" - ) + try: + # Validate all arrays have same length before zipping + if len(image_ids) != len(images): + raise ValueError( + f"Length mismatch: image_ids ({len(image_ids)}) != images ({len(images)})" + ) + if len(image_ids) != len(reply_subjects) or len(image_ids) != len( + image_urls + ): + raise ValueError( + f"Length mismatch: image_ids ({len(image_ids)}), " + f"reply_subjects ({len(reply_subjects)}), image_urls ({len(image_urls)})" + ) - # Track start time for this batch - batch_start_time = datetime.datetime.now() + # Track start time for this batch + batch_start_time = datetime.datetime.now() - logger.info(f"Processing batch {i + 1}") - # output is dict of "boxes", "labels", "scores" - batch_output = [] - if len(images) > 0: - batch_output = detector.predict_batch(images) + logger.info(f"Processing batch {i + 1}") + # output is dict of "boxes", "labels", "scores" + batch_output = [] + if len(images) > 0: + batch_output = detector.predict_batch(images) - items += len(batch_output) - logger.info(f"Total items processed so far: {items}") - batch_output = list(detector.post_process_batch(batch_output)) + items += len(batch_output) + logger.info(f"Total items processed so far: {items}") + batch_output = list(detector.post_process_batch(batch_output)) - # Convert image_ids to list if needed - if isinstance(image_ids, (np.ndarray, torch.Tensor)): - image_ids = image_ids.tolist() + # Convert image_ids to list if needed + if isinstance(image_ids, (np.ndarray, torch.Tensor)): + image_ids = image_ids.tolist() - # TODO CGJS: Add seconds per item calculation for both detector and classifier - detector.save_results( - item_ids=image_ids, - batch_output=batch_output, - seconds_per_item=0, - ) - dt, t = t("Finished detection") - total_detection_time += dt - - # Group detections by image_id - image_detections: dict[str, list[DetectionResponse]] = { - img_id: [] for img_id in image_ids - } - image_tensors = dict(zip(image_ids, images, strict=True)) - - classifier.reset(detector.results) - - for idx, dresp in enumerate(detector.results): - image_tensor = image_tensors[dresp.source_image_id] - bbox = dresp.bbox - # crop the image tensor using the bbox - crop = image_tensor[ - :, int(bbox.y1) : int(bbox.y2), int(bbox.x1) : int(bbox.x2) - ] - crop = crop.unsqueeze(0) # add batch dimension - classifier_out = classifier.predict_batch(crop) - classifier_out = classifier.post_process_batch(classifier_out) - detection = classifier.update_detection_classification( + # TODO CGJS: Add seconds per item calculation for both detector and classifier + detector.save_results( + item_ids=image_ids, + 
batch_output=batch_output, seconds_per_item=0, - image_id=dresp.source_image_id, - detection_idx=idx, - predictions=classifier_out[0], - ) - image_detections[dresp.source_image_id].append(detection) - all_detections.append(detection) - - ct, t = t("Finished classification") - total_classification_time += ct - - # Calculate batch processing time - batch_end_time = datetime.datetime.now() - batch_elapsed = (batch_end_time - batch_start_time).total_seconds() - - # Post results back to the API with PipelineResponse for each image - batch_results: list[AntennaTaskResult] = [] - for reply_subject, image_id, image_url in zip( - reply_subjects, image_ids, image_urls, strict=True - ): - # Create SourceImageResponse for this image - source_image = SourceImageResponse(id=image_id, url=image_url) - - # Create PipelineResultsResponse - pipeline_response = PipelineResultsResponse( - pipeline=pipeline, - source_images=[source_image], - detections=image_detections[image_id], - total_time=batch_elapsed / len(image_ids), # Approximate time per image ) + dt, t = t("Finished detection") + total_detection_time += dt + + # Group detections by image_id + image_detections: dict[str, list[DetectionResponse]] = { + img_id: [] for img_id in image_ids + } + image_tensors = dict(zip(image_ids, images, strict=True)) + + classifier.reset(detector.results) + + for idx, dresp in enumerate(detector.results): + image_tensor = image_tensors[dresp.source_image_id] + bbox = dresp.bbox + # crop the image tensor using the bbox + crop = image_tensor[ + :, int(bbox.y1) : int(bbox.y2), int(bbox.x1) : int(bbox.x2) + ] + crop = crop.unsqueeze(0) # add batch dimension + classifier_out = classifier.predict_batch(crop) + classifier_out = classifier.post_process_batch(classifier_out) + detection = classifier.update_detection_classification( + seconds_per_item=0, + image_id=dresp.source_image_id, + detection_idx=idx, + predictions=classifier_out[0], + ) + image_detections[dresp.source_image_id].append(detection) + all_detections.append(detection) + + ct, t = t("Finished classification") + total_classification_time += ct + + # Calculate batch processing time + batch_end_time = datetime.datetime.now() + batch_elapsed = (batch_end_time - batch_start_time).total_seconds() + + # Post results back to the API with PipelineResponse for each image + batch_results: list[AntennaTaskResult] = [] + for reply_subject, image_id, image_url in zip( + reply_subjects, image_ids, image_urls, strict=True + ): + # Create SourceImageResponse for this image + source_image = SourceImageResponse(id=image_id, url=image_url) + + # Create PipelineResultsResponse + pipeline_response = PipelineResultsResponse( + pipeline=pipeline, + source_images=[source_image], + detections=image_detections[image_id], + total_time=batch_elapsed + / len(image_ids), # Approximate time per image + ) - batch_results.append( - AntennaTaskResult( - reply_subject=reply_subject, - result=pipeline_response, + batch_results.append( + AntennaTaskResult( + reply_subject=reply_subject, + result=pipeline_response, + ) ) - ) + except Exception as e: + logger.error(f"Batch {i + 1} failed during processing: {e}", exc_info=True) + # Report errors back to Antenna so tasks aren't stuck in the queue + batch_results = [] + for reply_subject, image_id in zip(reply_subjects, image_ids): + batch_results.append( + AntennaTaskResult( + reply_subject=reply_subject, + result=AntennaTaskResultError( + error=f"Batch processing error: {e}", + image_id=image_id, + ), + ) + ) + failed_items = batch.get("failed_items") 
if failed_items: for failed_item in failed_items: @@ -236,12 +255,10 @@ def _process_job( st, t = t("Finished posting results") if not success: - error_msg = ( + logger.error( f"Failed to post {len(batch_results)} results for job {job_id} to " f"{settings.antenna_api_base_url}. Batch processing data lost." ) - logger.error(error_msg) - raise RuntimeError(error_msg) total_save_time += st From 7e289b432a8505b6a884824e9eb50bb4e2c7561f Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Mon, 9 Feb 2026 18:30:28 -0800 Subject: [PATCH 03/17] perf: batch classification crops in worker instead of N individual GPU calls Collect all detection crops, apply classifier transforms (which resize to uniform input_size), then run a single batched predict_batch call. Skips detections with invalid bounding boxes (y1 >= y2 or x1 >= x2). This supersedes PR #105. Co-Authored-By: Claude Opus 4.6 --- trapdata/antenna/worker.py | 48 +++++++++++++++++++++++++++----------- 1 file changed, 34 insertions(+), 14 deletions(-) diff --git a/trapdata/antenna/worker.py b/trapdata/antenna/worker.py index f31e74e..2fbec4f 100644 --- a/trapdata/antenna/worker.py +++ b/trapdata/antenna/worker.py @@ -5,6 +5,7 @@ import numpy as np import torch +import torchvision from trapdata.antenna.client import get_jobs, post_batch_results from trapdata.antenna.datasets import get_rest_dataloader @@ -168,25 +169,44 @@ def _process_job( image_tensors = dict(zip(image_ids, images, strict=True)) classifier.reset(detector.results) + to_pil = torchvision.transforms.ToPILImage() + classify_transforms = classifier.get_transforms() + # Collect and transform all crops for batched classification + crops = [] + valid_indices = [] for idx, dresp in enumerate(detector.results): image_tensor = image_tensors[dresp.source_image_id] bbox = dresp.bbox - # crop the image tensor using the bbox - crop = image_tensor[ - :, int(bbox.y1) : int(bbox.y2), int(bbox.x1) : int(bbox.x2) - ] - crop = crop.unsqueeze(0) # add batch dimension - classifier_out = classifier.predict_batch(crop) + y1, y2 = int(bbox.y1), int(bbox.y2) + x1, x2 = int(bbox.x1), int(bbox.x2) + if y1 >= y2 or x1 >= x2: + logger.warning( + f"Skipping detection {idx} with invalid bbox: " + f"({x1},{y1})->({x2},{y2})" + ) + continue + crop = image_tensor[:, y1:y2, x1:x2] + crop_pil = to_pil(crop) + crop_transformed = classify_transforms(crop_pil) + crops.append(crop_transformed) + valid_indices.append(idx) + + if crops: + batched_crops = torch.stack(crops) + classifier_out = classifier.predict_batch(batched_crops) classifier_out = classifier.post_process_batch(classifier_out) - detection = classifier.update_detection_classification( - seconds_per_item=0, - image_id=dresp.source_image_id, - detection_idx=idx, - predictions=classifier_out[0], - ) - image_detections[dresp.source_image_id].append(detection) - all_detections.append(detection) + + for crop_i, idx in enumerate(valid_indices): + dresp = detector.results[idx] + detection = classifier.update_detection_classification( + seconds_per_item=0, + image_id=dresp.source_image_id, + detection_idx=idx, + predictions=classifier_out[crop_i], + ) + image_detections[dresp.source_image_id].append(detection) + all_detections.append(detection) ct, t = t("Finished classification") total_classification_time += ct From 23a90e3c40eee038483ddedef7696f6f6db07f0b Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Mon, 9 Feb 2026 18:51:31 -0800 Subject: [PATCH 04/17] perf: increase default localization_batch_size and num_workers for GPU MIME-Version: 1.0 Content-Type: 
text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit localization_batch_size 2 → 8 and num_workers 1 → 4. The old defaults were far too conservative for 24GB VRAM GPUs. These can still be overridden via AMI_LOCALIZATION_BATCH_SIZE and AMI_NUM_WORKERS env vars. Co-Authored-By: Claude Opus 4.6 --- trapdata/settings.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trapdata/settings.py b/trapdata/settings.py index f4b83f1..f7286f0 100644 --- a/trapdata/settings.py +++ b/trapdata/settings.py @@ -33,9 +33,9 @@ class Settings(BaseSettings): default=ml.models.DEFAULT_FEATURE_EXTRACTOR ) classification_threshold: float = 0.6 - localization_batch_size: int = 2 + localization_batch_size: int = 8 classification_batch_size: int = 20 - num_workers: int = 1 + num_workers: int = 4 # Antenna API worker settings antenna_api_base_url: str = "http://localhost:8000/api/v2" From 8557298da6cb22e2da28cd25d84258a7a611a286 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Mon, 9 Feb 2026 18:51:47 -0800 Subject: [PATCH 05/17] feat: spawn one worker process per GPU for multi-GPU inference run_worker() now detects torch.cuda.device_count() and uses torch.multiprocessing.spawn to launch one worker per GPU. Each worker pins itself to a specific GPU via set_device(). Single-GPU and CPU-only machines keep existing single-process behavior with no overhead. Also fixes get_device() to use current_device() instead of bare "cuda" so that models load onto the correct GPU in spawned workers. Co-Authored-By: Claude Opus 4.6 --- trapdata/antenna/worker.py | 56 ++++++++++++++++++++++++++++++++++---- trapdata/ml/utils.py | 10 +++++-- 2 files changed, 58 insertions(+), 8 deletions(-) diff --git a/trapdata/antenna/worker.py b/trapdata/antenna/worker.py index 2fbec4f..318bdd9 100644 --- a/trapdata/antenna/worker.py +++ b/trapdata/antenna/worker.py @@ -5,6 +5,7 @@ import numpy as np import torch +import torch.multiprocessing as mp import torchvision from trapdata.antenna.client import get_jobs, post_batch_results @@ -25,7 +26,11 @@ def run_worker(pipelines: list[str]): - """Run the worker to process images from the REST API queue.""" + """Run the worker to process images from the REST API queue. + + Automatically spawns one worker process per available GPU. + On single-GPU or CPU-only machines, runs in-process (no overhead). + """ settings = read_settings() # Validate auth token @@ -35,20 +40,57 @@ def run_worker(pipelines: list[str]): "Get your auth token from your Antenna project settings." ) + gpu_count = torch.cuda.device_count() + + if gpu_count > 1: + logger.info(f"Found {gpu_count} GPUs, spawning one worker per GPU") + # Don't pass settings through mp.spawn — Settings contains enums that + # can't be pickled. Each child process calls read_settings() itself. + mp.spawn( + _worker_loop, + args=(pipelines,), + nprocs=gpu_count, + join=True, + ) + else: + if gpu_count == 1: + logger.info(f"Found 1 GPU: {torch.cuda.get_device_name(0)}") + else: + logger.info("No GPUs found, running on CPU") + _worker_loop(0, pipelines) + + +def _worker_loop(gpu_id: int, pipelines: list[str]): + """Main polling loop for a single worker, pinned to a specific GPU. + + Args: + gpu_id: GPU index to pin this worker to (0 for CPU-only). + pipelines: List of pipeline slugs to poll for jobs. 
+ """ + settings = read_settings() + + if torch.cuda.is_available() and torch.cuda.device_count() > 0: + torch.cuda.set_device(gpu_id) + logger.info( + f"Worker {gpu_id} pinned to GPU {gpu_id}: {torch.cuda.get_device_name(gpu_id)}" + ) + while True: # TODO CGJS: Support pulling and prioritizing single image tasks, which are used in interactive testing # These should probably come from a dedicated endpoint and should preempt batch jobs under the assumption that they # would run on the same GPU. any_jobs = False for pipeline in pipelines: - logger.info(f"Checking for jobs for pipeline {pipeline}") + logger.info(f"[GPU {gpu_id}] Checking for jobs for pipeline {pipeline}") jobs = get_jobs( base_url=settings.antenna_api_base_url, auth_token=settings.antenna_api_auth_token, pipeline_slug=pipeline, ) for job_id in jobs: - logger.info(f"Processing job {job_id} with pipeline {pipeline}") + logger.info( + f"[GPU {gpu_id}] Processing job {job_id} with pipeline {pipeline}" + ) try: any_work_done = _process_job( pipeline=pipeline, @@ -58,13 +100,15 @@ def run_worker(pipelines: list[str]): any_jobs = any_jobs or any_work_done except Exception as e: logger.error( - f"Failed to process job {job_id} with pipeline {pipeline}: {e}", + f"[GPU {gpu_id}] Failed to process job {job_id} with pipeline {pipeline}: {e}", exc_info=True, ) # Continue to next job rather than crashing the worker if not any_jobs: - logger.info(f"No jobs found, sleeping for {SLEEP_TIME_SECONDS} seconds") + logger.info( + f"[GPU {gpu_id}] No jobs found, sleeping for {SLEEP_TIME_SECONDS} seconds" + ) time.sleep(SLEEP_TIME_SECONDS) @@ -139,7 +183,7 @@ def _process_job( # Track start time for this batch batch_start_time = datetime.datetime.now() - logger.info(f"Processing batch {i + 1}") + logger.info(f"Processing worker batch {i + 1} ({len(images)} images)") # output is dict of "boxes", "labels", "scores" batch_output = [] if len(images) > 0: diff --git a/trapdata/ml/utils.py b/trapdata/ml/utils.py index 3d52067..da09746 100644 --- a/trapdata/ml/utils.py +++ b/trapdata/ml/utils.py @@ -42,8 +42,14 @@ def get_device(device_str=None) -> torch.device: @TODO check Kivy settings to see if user forced use of CPU """ if not device_str: - device_str = "cuda" if torch.cuda.is_available() else "cpu" - device = torch.device(device_str) + if torch.cuda.is_available(): + # Use current_device() so mp.spawn workers that called + # torch.cuda.set_device(i) get the correct GPU index. + device = torch.device("cuda", torch.cuda.current_device()) + else: + device = torch.device("cpu") + else: + device = torch.device(device_str) logger.info(f"Using device '{device}' for inference") return device From 0627bd7664d9878f529308b27549214a6e4435ea Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Mon, 9 Feb 2026 18:51:54 -0800 Subject: [PATCH 06/17] fix: clarify batch size log messages to distinguish worker vs inference Worker batch logs now show image count ("Processing worker batch 3 (8 images)") and model inference logs include the model name ("Preparing FasterRCNN inference dataloader (batch_size=4, single worker mode)"). 
Co-Authored-By: Claude Opus 4.6 --- trapdata/ml/models/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trapdata/ml/models/base.py b/trapdata/ml/models/base.py index bb7d1fa..b086b90 100644 --- a/trapdata/ml/models/base.py +++ b/trapdata/ml/models/base.py @@ -244,11 +244,11 @@ def get_dataloader(self): """ if self.single: logger.info( - f"Preparing dataloader with batch size of {self.batch_size} in single worker mode." + f"Preparing {self.name} inference dataloader (batch_size={self.batch_size}, single worker mode)" ) else: logger.info( - f"Preparing dataloader with batch size of {self.batch_size} and {self.num_workers} workers." + f"Preparing {self.name} inference dataloader (batch_size={self.batch_size}, num_workers={self.num_workers})" ) dataloader_args = { "num_workers": 0 if self.single else self.num_workers, From 123b24e4f1e544135b7564b8d68796ec6ef2b421 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Mon, 9 Feb 2026 18:54:35 -0800 Subject: [PATCH 07/17] perf: increase DataLoader prefetch_factor to 4 for worker The default prefetch_factor of 2 means the DataLoader only prepares 2 batches ahead. With GPU inference taking ~2s and image downloads taking ~30s per batch, the GPU idles waiting for data. Bumping to 4 keeps the download pipeline fuller so the next batch is more likely to be ready. Co-Authored-By: Claude Opus 4.6 --- trapdata/antenna/datasets.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/trapdata/antenna/datasets.py b/trapdata/antenna/datasets.py index 22801c9..54c53fa 100644 --- a/trapdata/antenna/datasets.py +++ b/trapdata/antenna/datasets.py @@ -272,9 +272,14 @@ def get_rest_dataloader( batch_size=settings.antenna_api_batch_size, ) - return torch.utils.data.DataLoader( - dataset, - batch_size=settings.localization_batch_size, - num_workers=settings.num_workers, - collate_fn=rest_collate_fn, - ) + dataloader_kwargs: dict = { + "batch_size": settings.localization_batch_size, + "num_workers": settings.num_workers, + "collate_fn": rest_collate_fn, + } + if settings.num_workers > 0: + # Prefetch more batches so the next batch is already downloading + # while the GPU processes the current one. Default is 2. + dataloader_kwargs["prefetch_factor"] = 4 + + return torch.utils.data.DataLoader(dataset, **dataloader_kwargs) From 96b19f60ba5454c5b36c6805187609df71e5d22c Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Mon, 9 Feb 2026 18:55:59 -0800 Subject: [PATCH 08/17] perf: download images concurrently within each DataLoader worker Previously each DataLoader worker downloaded images sequentially: fetch task metadata, then download image 1, download image 2, ... download image N, yield all. With 32 tasks per API fetch, this meant ~30s of serial HTTP requests before a single image reached the GPU. Now _load_images_threaded() uses a ThreadPoolExecutor (up to 8 threads) to download all images in a task batch concurrently. Threads are ideal here because image downloads are I/O-bound (network latency), not CPU-bound, and the requests Session's connection pool is thread-safe. This stacks with the existing DataLoader num_workers parallelism: - num_workers: N independent DataLoader processes, each polling the API - ThreadPoolExecutor: within each process, M concurrent image downloads - prefetch_factor: DataLoader queues future batches while GPU is busy Expected improvement: download time for a batch of 32 images drops from ~30s (sequential) to ~4-5s (8 concurrent threads), keeping the GPU fed. 
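The download step, as a minimal sketch (load_image here stands in for the
dataset's _load_image method):

    from concurrent.futures import ThreadPoolExecutor, as_completed

    def load_images_threaded(tasks, max_threads=8):
        if not tasks:
            return {}
        results = {}
        with ThreadPoolExecutor(max_workers=min(len(tasks), max_threads)) as executor:
            # Each thread blocks on network I/O, so the downloads overlap
            futures = {
                executor.submit(load_image, task.image_url): task.image_id
                for task in tasks
            }
            for future in as_completed(futures):
                results[futures[future]] = future.result()
        return results

Back-of-envelope: at roughly 1s per download, 32 sequential requests take
~30s, while 8 threads clear them in ceil(32 / 8) = 4 waves of about 4s
total, which is where the estimate above comes from.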
Co-Authored-By: Claude Opus 4.6 --- trapdata/antenna/datasets.py | 53 ++++++++++++++++++++++++++++++------ 1 file changed, 45 insertions(+), 8 deletions(-) diff --git a/trapdata/antenna/datasets.py b/trapdata/antenna/datasets.py index 54c53fa..590b1eb 100644 --- a/trapdata/antenna/datasets.py +++ b/trapdata/antenna/datasets.py @@ -1,6 +1,7 @@ """Dataset classes for streaming tasks from the Antenna API.""" import typing +from concurrent.futures import ThreadPoolExecutor, as_completed from io import BytesIO import requests @@ -120,17 +121,55 @@ def _load_image(self, image_url: str) -> torch.Tensor | None: logger.error(f"Failed to load image from {image_url}: {e}") return None + def _load_images_threaded( + self, + tasks: list[AntennaPipelineProcessingTask], + ) -> dict[str, torch.Tensor | None]: + """Download images for a batch of tasks using concurrent threads. + + Image downloads are I/O-bound (network latency, not CPU), so threads + provide near-linear speedup without the overhead of extra processes. + The HTTP session's connection pool is thread-safe and reuses TCP + connections across threads. + + Args: + tasks: List of tasks whose images should be downloaded. + + Returns: + Mapping from image_id to tensor (or None on failure), preserving + the order needed by the caller. + """ + results: dict[str, torch.Tensor | None] = {} + + def _download( + task: AntennaPipelineProcessingTask, + ) -> tuple[str, torch.Tensor | None]: + tensor = self._load_image(task.image_url) if task.image_url else None + return (task.image_id, tensor) + + max_threads = min(len(tasks), 8) + with ThreadPoolExecutor(max_workers=max_threads) as executor: + futures = {executor.submit(_download, t): t for t in tasks} + for future in as_completed(futures): + image_id, tensor = future.result() + results[image_id] = tensor + + return results + def __iter__(self): """ Iterate over tasks from the REST API. + Each API fetch returns a batch of tasks. Images for the entire batch + are downloaded concurrently using threads (see _load_images_threaded), + then yielded one at a time for the DataLoader to collate. 
+ Yields: Dictionary containing: - image: PyTorch tensor of the loaded image - reply_subject: Reply subject for the task - - batch_index: Index of the image in the batch - - job_id: Job ID - image_id: Image ID + - image_url: Source URL """ worker_id = 0 # Initialize before try block to avoid UnboundLocalError try: @@ -160,14 +199,12 @@ def __iter__(self): ) break + # Download all images concurrently + image_map = self._load_images_threaded(tasks) + for task in tasks: + image_tensor = image_map.get(task.image_id) errors = [] - # Load the image - # _, t = log_time() - image_tensor = ( - self._load_image(task.image_url) if task.image_url else None - ) - # _, t = t(f"Loaded image from {image_url}") if image_tensor is None: errors.append("failed to load image") From 57bfed2ccdc6faa2192c97fdac3af46cedecc1e2 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Thu, 12 Feb 2026 14:43:19 -0800 Subject: [PATCH 09/17] docs: clarify types of workers --- trapdata/antenna/datasets.py | 14 +++++++------- trapdata/antenna/worker.py | 10 +++++----- trapdata/ml/models/base.py | 2 +- trapdata/settings.py | 7 +++++-- 4 files changed, 18 insertions(+), 15 deletions(-) diff --git a/trapdata/antenna/datasets.py b/trapdata/antenna/datasets.py index 590b1eb..f3546c3 100644 --- a/trapdata/antenna/datasets.py +++ b/trapdata/antenna/datasets.py @@ -33,9 +33,9 @@ class RESTDataset(torch.utils.data.IterableDataset): DataLoader workers are SAFE and won't process duplicate tasks. Each worker independently fetches different tasks from the shared queue. - With num_workers > 0: - Worker 1: GET /tasks → receives [1,2,3,4], removed from queue - Worker 2: GET /tasks → receives [5,6,7,8], removed from queue + With DataLoader num_workers > 0 (I/O subprocesses, not AMI instances): + Subprocess 1: GET /tasks → receives [1,2,3,4], removed from queue + Subprocess 2: GET /tasks → receives [5,6,7,8], removed from queue No duplicates, safe for parallel processing """ @@ -179,7 +179,7 @@ def __iter__(self): num_workers = worker_info.num_workers if worker_info else 1 logger.info( - f"Worker {worker_id}/{num_workers} starting iteration for job {self.job_id}" + f"DataLoader subprocess {worker_id}/{num_workers} starting iteration for job {self.job_id}" ) while True: @@ -292,10 +292,10 @@ def get_rest_dataloader( """ Create a DataLoader that fetches tasks from Antenna API. - Note: num_workers > 0 is SAFE here (unlike local file reading) because: + Note: DataLoader num_workers > 0 is SAFE here (unlike local file reading) because: - Antenna API provides atomic task dequeue (work queue pattern) - - No shared file handles between workers - - Each worker gets different tasks automatically + - No shared file handles between subprocesses + - Each subprocess gets different tasks automatically - Parallel downloads improve throughput for I/O-bound work Args: diff --git a/trapdata/antenna/worker.py b/trapdata/antenna/worker.py index 318bdd9..bb75f5f 100644 --- a/trapdata/antenna/worker.py +++ b/trapdata/antenna/worker.py @@ -28,7 +28,7 @@ def run_worker(pipelines: list[str]): """Run the worker to process images from the REST API queue. - Automatically spawns one worker process per available GPU. + Automatically spawns one AMI worker instance process per available GPU. On single-GPU or CPU-only machines, runs in-process (no overhead). 
""" settings = read_settings() @@ -43,7 +43,7 @@ def run_worker(pipelines: list[str]): gpu_count = torch.cuda.device_count() if gpu_count > 1: - logger.info(f"Found {gpu_count} GPUs, spawning one worker per GPU") + logger.info(f"Found {gpu_count} GPUs, spawning one AMI worker instance per GPU") # Don't pass settings through mp.spawn — Settings contains enums that # can't be pickled. Each child process calls read_settings() itself. mp.spawn( @@ -61,10 +61,10 @@ def run_worker(pipelines: list[str]): def _worker_loop(gpu_id: int, pipelines: list[str]): - """Main polling loop for a single worker, pinned to a specific GPU. + """Main polling loop for a single AMI worker instance, pinned to a specific GPU. Args: - gpu_id: GPU index to pin this worker to (0 for CPU-only). + gpu_id: GPU index to pin this AMI worker instance to (0 for CPU-only). pipelines: List of pipeline slugs to poll for jobs. """ settings = read_settings() @@ -72,7 +72,7 @@ def _worker_loop(gpu_id: int, pipelines: list[str]): if torch.cuda.is_available() and torch.cuda.device_count() > 0: torch.cuda.set_device(gpu_id) logger.info( - f"Worker {gpu_id} pinned to GPU {gpu_id}: {torch.cuda.get_device_name(gpu_id)}" + f"AMI worker instance {gpu_id} pinned to GPU {gpu_id}: {torch.cuda.get_device_name(gpu_id)}" ) while True: diff --git a/trapdata/ml/models/base.py b/trapdata/ml/models/base.py index b086b90..1c694a7 100644 --- a/trapdata/ml/models/base.py +++ b/trapdata/ml/models/base.py @@ -248,7 +248,7 @@ def get_dataloader(self): ) else: logger.info( - f"Preparing {self.name} inference dataloader (batch_size={self.batch_size}, num_workers={self.num_workers})" + f"Preparing {self.name} inference dataloader (batch_size={self.batch_size}, dataloader_workers={self.num_workers})" ) dataloader_args = { "num_workers": 0 if self.single else self.num_workers, diff --git a/trapdata/settings.py b/trapdata/settings.py index f7286f0..ee3c9ba 100644 --- a/trapdata/settings.py +++ b/trapdata/settings.py @@ -143,8 +143,11 @@ class Config: "kivy_section": "performance", }, "num_workers": { - "title": "Number of workers", - "description": "Number of parallel workers for the PyTorch dataloader. See https://pytorch.org/docs/stable/data.html", + "title": "DataLoader workers", + "description": ( + "Number of parallel subprocesses for the PyTorch DataLoader (image downloading & preprocessing). " + "See https://pytorch.org/docs/stable/data.html" + ), "kivy_type": "numeric", "kivy_section": "performance", }, From d9ae653f8a087c289ccc6a989ba6ffe559e2b287 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Thu, 12 Feb 2026 16:22:10 -0800 Subject: [PATCH 10/17] refactor: make ThreadPoolExecutor a class member, simplify with map() - Keep the thread pool alive across batches instead of recreating it - Use executor.map() instead of submit + as_completed - Fix docstring: requests.Session is not formally thread-safe Addresses review comments from @carlosgjs. 
Co-Authored-By: Claude Opus 4.6 --- trapdata/antenna/datasets.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/trapdata/antenna/datasets.py b/trapdata/antenna/datasets.py index f3546c3..93d1c6c 100644 --- a/trapdata/antenna/datasets.py +++ b/trapdata/antenna/datasets.py @@ -1,7 +1,7 @@ """Dataset classes for streaming tasks from the Antenna API.""" import typing -from concurrent.futures import ThreadPoolExecutor, as_completed +from concurrent.futures import ThreadPoolExecutor from io import BytesIO import requests @@ -67,8 +67,13 @@ def __init__( self.api_session = get_http_session(auth_token) self.image_fetch_session = get_http_session() # No auth for external image URLs + # Reusable thread pool for concurrent image downloads + self._executor = ThreadPoolExecutor(max_workers=8) + def __del__(self): - """Clean up HTTP sessions on dataset destruction.""" + """Clean up HTTP sessions and thread pool on dataset destruction.""" + if hasattr(self, "_executor"): + self._executor.shutdown(wait=False) if hasattr(self, "api_session"): self.api_session.close() if hasattr(self, "image_fetch_session"): @@ -129,8 +134,10 @@ def _load_images_threaded( Image downloads are I/O-bound (network latency, not CPU), so threads provide near-linear speedup without the overhead of extra processes. - The HTTP session's connection pool is thread-safe and reuses TCP - connections across threads. + Note: ``requests.Session`` is not formally thread-safe, but the + underlying urllib3 connection pool handles concurrent socket access. + In practice shared read-only sessions work fine for GET requests; + if issues arise, switch to per-thread sessions. Args: tasks: List of tasks whose images should be downloaded. @@ -139,7 +146,6 @@ def _load_images_threaded( Mapping from image_id to tensor (or None on failure), preserving the order needed by the caller. """ - results: dict[str, torch.Tensor | None] = {} def _download( task: AntennaPipelineProcessingTask, @@ -147,14 +153,7 @@ def _download( tensor = self._load_image(task.image_url) if task.image_url else None return (task.image_id, tensor) - max_threads = min(len(tasks), 8) - with ThreadPoolExecutor(max_workers=max_threads) as executor: - futures = {executor.submit(_download, t): t for t in tasks} - for future in as_completed(futures): - image_id, tensor = future.result() - results[image_id] = tensor - - return results + return dict(self._executor.map(_download, tasks)) def __iter__(self): """ From 7abad401de4750d6e31ca3042f57c7a9b1adb889 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Thu, 12 Feb 2026 16:22:26 -0800 Subject: [PATCH 11/17] docs: fix stale docstring in rest_collate_fn MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit "Stacked tensor" → "List of image tensors" to match the actual return type after the variable-size image change. Co-Authored-By: Claude Opus 4.6 --- trapdata/antenna/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trapdata/antenna/datasets.py b/trapdata/antenna/datasets.py index 93d1c6c..8e0f619 100644 --- a/trapdata/antenna/datasets.py +++ b/trapdata/antenna/datasets.py @@ -235,7 +235,7 @@ def rest_collate_fn(batch: list[dict]) -> dict: Custom collate function that separates failed and successful items. 
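For reference, the behaviour the docstring now describes (hypothetical
image sizes):

    import torch

    imgs = [torch.rand(3, 608, 342), torch.rand(3, 409, 216)]
    # torch.stack(imgs)  # would raise: stack expects each tensor to be equal size
    batch = {"images": imgs}  # a plain list keeps each image's own resolution;
                              # FasterRCNN accepts a list of CxHxW tensors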
Returns a dict with: - - images: Stacked tensor of valid images (only present if there are successful items) + - images: List of image tensors (only present if there are successful items) - reply_subjects: List of reply subjects for valid images - image_ids: List of image IDs for valid images - image_urls: List of image URLs for valid images From 4baa3d7c12ef620db98673e63dc2b39f3000a4c0 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Thu, 12 Feb 2026 16:22:55 -0800 Subject: [PATCH 12/17] fix: defensive batch_results init and strict=True in error handler - Initialize batch_results before the try block to prevent potential NameError if a future refactor introduces a path between try/except and the post-results code. - Add strict=True to zip(reply_subjects, image_ids) in the except block so length mismatches raise immediately rather than silently dropping error reports. Co-Authored-By: Claude Opus 4.6 --- trapdata/antenna/worker.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/trapdata/antenna/worker.py b/trapdata/antenna/worker.py index bb75f5f..bba7905 100644 --- a/trapdata/antenna/worker.py +++ b/trapdata/antenna/worker.py @@ -166,6 +166,8 @@ def _process_job( reply_subjects = batch.get("reply_subjects", [None] * len(images)) image_urls = batch.get("image_urls", [None] * len(images)) + batch_results: list[AntennaTaskResult] = [] + try: # Validate all arrays have same length before zipping if len(image_ids) != len(images): @@ -260,7 +262,7 @@ def _process_job( batch_elapsed = (batch_end_time - batch_start_time).total_seconds() # Post results back to the API with PipelineResponse for each image - batch_results: list[AntennaTaskResult] = [] + batch_results.clear() for reply_subject, image_id, image_url in zip( reply_subjects, image_ids, image_urls, strict=True ): @@ -286,7 +288,7 @@ def _process_job( logger.error(f"Batch {i + 1} failed during processing: {e}", exc_info=True) # Report errors back to Antenna so tasks aren't stuck in the queue batch_results = [] - for reply_subject, image_id in zip(reply_subjects, image_ids): + for reply_subject, image_id in zip(reply_subjects, image_ids, strict=True): batch_results.append( AntennaTaskResult( reply_subject=reply_subject, From 27d9da2a3669f965e3a462828efb0e41984d99f0 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Thu, 12 Feb 2026 16:23:13 -0800 Subject: [PATCH 13/17] revert: remove prefetch_factor=4 override, use PyTorch default No measured improvement over the default (2). The override just increases memory usage without demonstrated benefit. Co-Authored-By: Claude Opus 4.6 --- trapdata/antenna/datasets.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/trapdata/antenna/datasets.py b/trapdata/antenna/datasets.py index 8e0f619..3f57265 100644 --- a/trapdata/antenna/datasets.py +++ b/trapdata/antenna/datasets.py @@ -308,14 +308,9 @@ def get_rest_dataloader( batch_size=settings.antenna_api_batch_size, ) - dataloader_kwargs: dict = { - "batch_size": settings.localization_batch_size, - "num_workers": settings.num_workers, - "collate_fn": rest_collate_fn, - } - if settings.num_workers > 0: - # Prefetch more batches so the next batch is already downloading - # while the GPU processes the current one. Default is 2. 
- dataloader_kwargs["prefetch_factor"] = 4 - - return torch.utils.data.DataLoader(dataset, **dataloader_kwargs) + return torch.utils.data.DataLoader( + dataset, + batch_size=settings.localization_batch_size, + num_workers=settings.num_workers, + collate_fn=rest_collate_fn, + ) From 3319c6059c00964051ca4bef5e77830982ec16aa Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Thu, 12 Feb 2026 16:39:06 -0800 Subject: [PATCH 14/17] docs: document data loading pipeline concurrency layers Add a module-level overview mapping the three concurrency layers (GPU processes, DataLoader workers, thread pool) to their settings, what work each layer does, and what runs under the GIL. Includes a "not yet benchmarked" section so future contributors know which knobs are tuned empirically vs speculatively. Co-Authored-By: Claude Opus 4.6 --- trapdata/antenna/datasets.py | 90 +++++++++++++++++++++++++++++++----- 1 file changed, 79 insertions(+), 11 deletions(-) diff --git a/trapdata/antenna/datasets.py b/trapdata/antenna/datasets.py index 3f57265..d234c37 100644 --- a/trapdata/antenna/datasets.py +++ b/trapdata/antenna/datasets.py @@ -1,4 +1,64 @@ -"""Dataset classes for streaming tasks from the Antenna API.""" +"""Dataset and DataLoader for streaming tasks from the Antenna API. + +Data loading pipeline overview +============================== + +The pipeline has three layers of concurrency. Each layer is controlled by a +different setting and targets a different bottleneck. + +:: + + ┌──────────────────────────────────────────────────────────────────┐ + │ GPU process (_worker_loop in worker.py) │ + │ One per GPU. Runs detection → classification on batches. │ + │ Controlled by: automatic (one per torch.cuda.device_count()) │ + ├──────────────────────────────────────────────────────────────────┤ + │ DataLoader workers (num_workers subprocesses) │ + │ Each subprocess runs its own RESTDataset.__iter__ loop: │ + │ 1. GET /tasks → fetch batch of task metadata from Antenna │ + │ 2. Download images (threaded, see below) │ + │ 3. Yield individual (image_tensor, metadata) rows │ + │ The DataLoader collates rows into GPU-sized batches. │ + │ Controlled by: settings.num_workers (AMI_NUM_WORKERS) │ + │ Default: 4. Safe >0 because Antenna dequeues atomically. │ + ├──────────────────────────────────────────────────────────────────┤ + │ Thread pool (ThreadPoolExecutor inside each DataLoader worker) │ + │ Downloads images concurrently *within* one API fetch batch. │ + │ Each thread: HTTP GET → PIL open → RGB convert → ToTensor(). │ + │ Controlled by: ThreadPoolExecutor(max_workers=8) on the class. │ + │ Note: RGB conversion and ToTensor are GIL-bound (CPU). Only │ + │ the network wait truly runs in parallel. A future optimisation │ + │ could move transforms out of the thread. │ + └──────────────────────────────────────────────────────────────────┘ + +Settings quick-reference (prefix with AMI_ as env vars): + + localization_batch_size (default 8) + How many images the GPU processes at once (detection). Larger = + more GPU memory. These are full-resolution images (~4K). + + num_workers (default 4) + DataLoader subprocesses. Each independently fetches tasks and + downloads images. More workers = more images prefetched for the + GPU, at the cost of CPU/RAM. With 0 workers, fetching and + inference are sequential (useful for debugging). + + antenna_api_batch_size (default 4) + How many task URLs to request from Antenna per API call. + Determines how many images are downloaded concurrently per + thread pool invocation. 
+ + prefetch_factor (PyTorch default: 2 when num_workers > 0) + Batches prefetched per worker. Not overridden here — the + default was tested and no improvement was measured by + increasing it (it just adds memory pressure). + +What has NOT been benchmarked yet (as of 2026-02): + - Optimal num_workers / thread count combination + - Whether moving transforms out of threads helps throughput + - Whether multiple DataLoader workers + threads overlap well + or contend on the GIL +""" import typing from concurrent.futures import ThreadPoolExecutor @@ -100,8 +160,12 @@ def _fetch_tasks(self) -> list[AntennaPipelineProcessingTask]: return tasks_response.tasks # Empty list is valid (queue drained) def _load_image(self, image_url: str) -> torch.Tensor | None: - """ - Load an image from a URL and convert it to a PyTorch tensor. + """Load an image from a URL and convert it to a PyTorch tensor. + + Called from threads inside ``_load_images_threaded``. The HTTP + fetch is truly concurrent (network I/O releases the GIL), but + PIL decode, RGB conversion, and ``image_transforms`` (ToTensor) + are CPU-bound and serialised by the GIL. Args: image_url: URL of the image to load @@ -288,18 +352,22 @@ def get_rest_dataloader( job_id: int, settings: "Settings", ) -> torch.utils.data.DataLoader: - """ - Create a DataLoader that fetches tasks from Antenna API. + """Create a DataLoader that fetches tasks from Antenna API. + + See the module docstring for an overview of the three concurrency + layers (GPU processes → DataLoader workers → thread pool) and which + settings control each. - Note: DataLoader num_workers > 0 is SAFE here (unlike local file reading) because: - - Antenna API provides atomic task dequeue (work queue pattern) - - No shared file handles between subprocesses - - Each subprocess gets different tasks automatically - - Parallel downloads improve throughput for I/O-bound work + DataLoader num_workers > 0 is safe here because Antenna dequeues + tasks atomically — each worker subprocess gets a unique set of tasks. Args: job_id: Job ID to fetch tasks for - settings: Settings object with antenna_api_* configuration + settings: Settings object. Relevant fields: + - antenna_api_base_url / antenna_api_auth_token + - antenna_api_batch_size (tasks per API call) + - localization_batch_size (images per GPU batch) + - num_workers (DataLoader subprocesses) """ dataset = RESTDataset( base_url=settings.antenna_api_base_url, From ed49153c2e5a338ee9a1ae2ef8ac15db4dc31876 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Thu, 12 Feb 2026 16:47:24 -0800 Subject: [PATCH 15/17] perf: tune defaults for higher GPU utilization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - antenna_api_batch_size: 4 → 16 (fetch enough tasks per API call to fill at least one GPU batch without an extra round trip) - num_workers: 4 → 2 (each worker prefetches more with larger API batches; fewer workers reduces CPU/RAM overhead) Co-Authored-By: Claude Opus 4.6 --- trapdata/antenna/datasets.py | 8 +++++--- trapdata/settings.py | 4 ++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/trapdata/antenna/datasets.py b/trapdata/antenna/datasets.py index d234c37..6dc0a49 100644 --- a/trapdata/antenna/datasets.py +++ b/trapdata/antenna/datasets.py @@ -37,16 +37,18 @@ How many images the GPU processes at once (detection). Larger = more GPU memory. These are full-resolution images (~4K). - num_workers (default 4) + num_workers (default 2) DataLoader subprocesses. 
Each independently fetches tasks and
        downloads images. More workers = more images prefetched for the
        GPU, at the cost of CPU/RAM. With 0 workers, fetching and
        inference are sequential (useful for debugging).
 
-    antenna_api_batch_size (default 4)
+    antenna_api_batch_size (default 16)
         How many task URLs to request from Antenna per API call.
         Determines how many images are downloaded concurrently per
-        thread pool invocation.
+        thread pool invocation. Should be >= localization_batch_size
+        so one API call can fill at least one GPU batch without an
+        extra round trip.
 
     prefetch_factor (PyTorch default: 2 when num_workers > 0)
         Batches prefetched per worker. Not overridden here — the

diff --git a/trapdata/settings.py b/trapdata/settings.py
index ee3c9ba..0020bd5 100644
--- a/trapdata/settings.py
+++ b/trapdata/settings.py
@@ -35,12 +35,12 @@ class Settings(BaseSettings):
     classification_threshold: float = 0.6
     localization_batch_size: int = 8
     classification_batch_size: int = 20
-    num_workers: int = 4
+    num_workers: int = 2
 
     # Antenna API worker settings
     antenna_api_base_url: str = "http://localhost:8000/api/v2"
     antenna_api_auth_token: str = ""
-    antenna_api_batch_size: int = 4
+    antenna_api_batch_size: int = 16
 
     @pydantic.field_validator("image_base_path", "user_data_path")
     def validate_path(cls, v):

From 5163e4433e4284eb0df59778b965a24d9de8923f Mon Sep 17 00:00:00 2001
From: Michael Bunsen
Date: Mon, 16 Feb 2026 17:59:44 -0800
Subject: [PATCH 16/17] fix: lazily init unpicklable objects in RESTDataset for num_workers>0

ThreadPoolExecutor (which contains a SimpleQueue) and requests.Session
objects created in __init__ cannot be pickled. PyTorch DataLoader with
num_workers>0 uses spawn, which pickles the dataset to send to worker
subprocesses.

Move session and executor creation to _ensure_sessions(), called on
first use in each worker process.

Add a regression test that creates a DataLoader with num_workers=2 and
verifies workers can start.

Co-Authored-By: Claude Opus 4.6
---
 trapdata/antenna/datasets.py          | 45 ++++++++++++++++++++-------
 trapdata/antenna/tests/test_worker.py | 23 ++++++++++++++
 2 files changed, 56 insertions(+), 12 deletions(-)

diff --git a/trapdata/antenna/datasets.py b/trapdata/antenna/datasets.py
index 6dc0a49..492f796 100644
--- a/trapdata/antenna/datasets.py
+++ b/trapdata/antenna/datasets.py
@@ -121,25 +121,40 @@ def __init__(
         """
         super().__init__()
         self.base_url = base_url
+        self.auth_token = auth_token
         self.job_id = job_id
         self.batch_size = batch_size
         self.image_transforms = image_transforms or torchvision.transforms.ToTensor()
 
-        # Create persistent sessions for connection pooling
-        self.api_session = get_http_session(auth_token)
-        self.image_fetch_session = get_http_session()  # No auth for external image URLs
+        # These are created lazily in _ensure_sessions() because they contain
+        # unpicklable objects (ThreadPoolExecutor has a SimpleQueue) and
+        # PyTorch DataLoader with num_workers>0 pickles the dataset to send
+        # it to worker subprocesses.
+        self._api_session: requests.Session | None = None
+        self._image_fetch_session: requests.Session | None = None
+        self._executor: ThreadPoolExecutor | None = None
 
-        # Reusable thread pool for concurrent image downloads
-        self._executor = ThreadPoolExecutor(max_workers=8)
+    def _ensure_sessions(self) -> None:
+        """Lazily create HTTP sessions and thread pool.
+
+        Called once per worker process on first use. This avoids pickling
+        issues with num_workers > 0 (SimpleQueue, socket objects, etc.).
+ """ + if self._api_session is None: + self._api_session = get_http_session(self.auth_token) + if self._image_fetch_session is None: + self._image_fetch_session = get_http_session() + if self._executor is None: + self._executor = ThreadPoolExecutor(max_workers=8) def __del__(self): """Clean up HTTP sessions and thread pool on dataset destruction.""" - if hasattr(self, "_executor"): + if self._executor is not None: self._executor.shutdown(wait=False) - if hasattr(self, "api_session"): - self.api_session.close() - if hasattr(self, "image_fetch_session"): - self.image_fetch_session.close() + if self._api_session is not None: + self._api_session.close() + if self._image_fetch_session is not None: + self._image_fetch_session.close() def _fetch_tasks(self) -> list[AntennaPipelineProcessingTask]: """ @@ -154,7 +169,9 @@ def _fetch_tasks(self) -> list[AntennaPipelineProcessingTask]: url = f"{self.base_url.rstrip('/')}/jobs/{self.job_id}/tasks" params = {"batch": self.batch_size} - response = self.api_session.get(url, params=params, timeout=30) + self._ensure_sessions() + assert self._api_session is not None + response = self._api_session.get(url, params=params, timeout=30) response.raise_for_status() # Parse and validate response with Pydantic @@ -177,7 +194,9 @@ def _load_image(self, image_url: str) -> torch.Tensor | None: """ try: # Use dedicated session without auth for external images - response = self.image_fetch_session.get(image_url, timeout=30) + self._ensure_sessions() + assert self._image_fetch_session is not None + response = self._image_fetch_session.get(image_url, timeout=30) response.raise_for_status() image = Image.open(BytesIO(response.content)) @@ -219,6 +238,8 @@ def _download( tensor = self._load_image(task.image_url) if task.image_url else None return (task.image_id, tensor) + self._ensure_sessions() + assert self._executor is not None return dict(self._executor.map(_download, tasks)) def __iter__(self): diff --git a/trapdata/antenna/tests/test_worker.py b/trapdata/antenna/tests/test_worker.py index 0a3b8dd..e9919ca 100644 --- a/trapdata/antenna/tests/test_worker.py +++ b/trapdata/antenna/tests/test_worker.py @@ -34,6 +34,29 @@ # --------------------------------------------------------------------------- +class TestDataLoaderMultiWorker(TestCase): + """DataLoader with num_workers > 0 must be able to start workers.""" + + def test_dataloader_starts_with_num_workers(self): + """Creating an iterator pickles the dataset to send to worker subprocesses.""" + dataset = RESTDataset( + base_url="http://localhost:1/api/v2", + auth_token="test-token", + job_id=1, + batch_size=4, + ) + loader = torch.utils.data.DataLoader( + dataset, + batch_size=2, + num_workers=2, + collate_fn=rest_collate_fn, + ) + # iter() pickles the dataset and spawns workers. + # If the dataset has unpicklable attributes this raises TypeError. + it = iter(loader) + del it + + class TestRestCollateFn(TestCase): """Tests for rest_collate_fn which separates successful/failed items.""" From 0857522ee52af869472b4e64eb45e95b8d51e0c7 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Mon, 16 Feb 2026 20:07:35 -0800 Subject: [PATCH 17/17] =?UTF-8?q?docs:=20fix=20stale=20num=5Fworkers=20def?= =?UTF-8?q?ault=20in=20datasets.py=20docstring=20(4=20=E2=86=92=202)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Matches the actual default changed in ed49153. 
Co-Authored-By: Claude Opus 4.6 --- trapdata/antenna/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trapdata/antenna/datasets.py b/trapdata/antenna/datasets.py index 492f796..25b33ba 100644 --- a/trapdata/antenna/datasets.py +++ b/trapdata/antenna/datasets.py @@ -20,7 +20,7 @@ │ 3. Yield individual (image_tensor, metadata) rows │ │ The DataLoader collates rows into GPU-sized batches. │ │ Controlled by: settings.num_workers (AMI_NUM_WORKERS) │ - │ Default: 4. Safe >0 because Antenna dequeues atomically. │ + │ Default: 2. Safe >0 because Antenna dequeues atomically. │ ├──────────────────────────────────────────────────────────────────┤ │ Thread pool (ThreadPoolExecutor inside each DataLoader worker) │ │ Downloads images concurrently *within* one API fetch batch. │