Merged
Changes from 16 commits

Commits (24)
b60eab0
merge
carlos-irreverentlabs Jan 16, 2026
644927f
Merge remote-tracking branch 'upstream/main'
carlosgjs Jan 22, 2026
218f7aa
Merge remote-tracking branch 'upstream/main'
carlosgjs Feb 3, 2026
90da389
Merge remote-tracking branch 'upstream/main'
carlosgjs Feb 10, 2026
8618d3c
Merge remote-tracking branch 'upstream/main'
carlosgjs Feb 13, 2026
bd1be5f
Merge remote-tracking branch 'upstream/main'
carlosgjs Feb 17, 2026
b102ae1
Merge remote-tracking branch 'upstream/main'
carlosgjs Feb 19, 2026
bc908aa
fix: PSv2 follow-up fixes from integration tests (#1135)
mihow Feb 21, 2026
4c3802a
PSv2: Improve task fetching & web worker concurrency configuration (#…
carlosgjs Feb 21, 2026
b717e80
fix: include pipeline_slug in MinimalJobSerializer (#1148)
mihow Feb 21, 2026
883c4f8
Merge remote-tracking branch 'upstream/main'
carlosgjs Feb 24, 2026
8df89be
Avoid redis based locking by using atomic updates
carlosgjs Feb 24, 2026
e26f3c6
Merge remote-tracking branch 'upstream/main'
carlosgjs Feb 24, 2026
1096fd9
Merge branch 'main' into carlosg/redisatomic
carlosgjs Feb 24, 2026
30c8db3
Test concurrency
carlosgjs Feb 25, 2026
deea095
Increase max ack pending
carlosgjs Feb 25, 2026
20c0fbd
update comment
carlosgjs Feb 25, 2026
e84421e
CR feedback
carlosgjs Feb 25, 2026
d591bd6
CR feedback
carlosgjs Feb 25, 2026
4720bb6
CR 2
carlosgjs Feb 26, 2026
f0cd403
fix: OrderedEnum comparisons now override str MRO in subclasses
mihow Feb 26, 2026
e3134a1
fix: correct misleading error log about NATS redelivery
mihow Feb 26, 2026
41b1232
Merge branch 'carlosg/redisatomic' of github.com:uw-ssec/antenna into…
carlosgjs Feb 26, 2026
dcf57fe
Use job.logger
carlosgjs Feb 26, 2026
28 changes: 14 additions & 14 deletions ami/jobs/tasks.py
@@ -84,15 +84,10 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub

state_manager = AsyncJobStateManager(job_id)

progress_info = state_manager.update_state(
processed_image_ids, stage="process", request_id=self.request.id, failed_image_ids=failed_image_ids
)
progress_info = state_manager.update_state(processed_image_ids, stage="process", failed_image_ids=failed_image_ids)
if not progress_info:
logger.warning(
f"Another task is already processing results for job {job_id}. "
f"Retrying task {self.request.id} in 5 seconds..."
)
raise self.retry(countdown=5, max_retries=10)
logger.error(f"Redis state missing for job {job_id} — job may have been cleaned up prematurely.")
return
Review comment (resolved by carlosgjs):
When Redis state is missing, the job stays stuck in STARTED forever (flagged by both Copilot and CodeRabbit). Suggestion: mark the job as FAILURE before returning.

Suggested change
Current:
logger.error(f"Redis state missing for job {job_id} — job may have been cleaned up prematurely.")
# Acknowledge the task to prevent retries, since we don't know the state
_ack_task_via_nats(reply_subject, logger)
# TODO: cancel the job to fail fast once PR #1144 is merged
return
Suggested:
logger.error(f"Redis state missing for job {job_id} — job may have been cleaned up prematurely.")
_ack_task_via_nats(reply_subject, logger)
_fail_job(job_id, "Redis state missing during process stage")
return
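For illustration only, a rough sketch of what a helper along these lines could do. This is not code from the PR; the import paths, the "process" stage name, and the Job API calls (progress.update_stage, logger, save) are assumptions based on what is visible elsewhere in this diff.

```python
# Hypothetical sketch, not part of this PR. Import paths and the Job API are
# assumed from context in this diff (job.progress.update_stage, job.logger).
from ami.jobs.models import Job, JobState


def _fail_job(job_id: int, reason: str) -> None:
    """Mark a job as failed so it does not stay stuck in STARTED."""
    job = Job.objects.filter(pk=job_id).first()
    if job is None:
        return
    job.logger.error(f"Failing job {job_id}: {reason}")
    # Mark the active stage as failed; assumes "process" is the stage in play.
    job.progress.update_stage("process", status=JobState.FAILURE)
    job.save()
```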


try:
complete_state = JobState.SUCCESS
@@ -150,15 +145,11 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub
progress_info = state_manager.update_state(
processed_image_ids,
stage="results",
request_id=self.request.id,
)

if not progress_info:
logger.warning(
f"Another task is already processing results for job {job_id}. "
f"Retrying task {self.request.id} in 5 seconds..."
)
raise self.retry(countdown=5, max_retries=10)
logger.error(f"Redis state missing for job {job_id} — job may have been cleaned up prematurely.")
return

# update complete state based on latest progress info after saving results
complete_state = JobState.SUCCESS
@@ -256,6 +247,15 @@ def _update_job_progress(
state_params["classifications"] = current_classifications + new_classifications
state_params["captures"] = current_captures + new_captures

# Don't overwrite a stage with a stale progress value.
# This guards against the race where a slower worker calls _update_job_progress
# after a faster worker has already marked further progress
try:
existing_stage = job.progress.get_stage(stage)
progress_percentage = max(existing_stage.progress, progress_percentage)
except (ValueError, AttributeError):
pass # Stage doesn't exist yet; proceed normally

job.progress.update_stage(
stage,
status=complete_state if progress_percentage >= 1.0 else JobState.STARTED,
57 changes: 31 additions & 26 deletions ami/jobs/test_tasks.py
@@ -17,7 +17,7 @@
from ami.jobs.tasks import process_nats_pipeline_result
from ami.main.models import Detection, Project, SourceImage, SourceImageCollection
from ami.ml.models import Pipeline
from ami.ml.orchestration.async_job_state import AsyncJobStateManager, _lock_key
from ami.ml.orchestration.async_job_state import AsyncJobStateManager
from ami.ml.schemas import PipelineResultsError, PipelineResultsResponse, SourceImageResponse
from ami.users.models import User

@@ -237,38 +237,43 @@ def test_process_nats_pipeline_result_mixed_results(self, mock_manager_class):
self.assertEqual(mock_manager.acknowledge_task.call_count, 3)

@patch("ami.jobs.tasks.TaskQueueManager")
def test_process_nats_pipeline_result_error_concurrent_locking(self, mock_manager_class):
def test_process_nats_pipeline_result_concurrent_updates(self, mock_manager_class):
"""
Test that error results respect locking mechanism.
Test that concurrent workers update state independently without contention.

Verifies race condition handling when multiple workers
process error results simultaneously.
Without a lock, two workers processing different images can both call
update_state and receive valid progress — no retry needed, no blocking.
"""
# Simulate lock held by another task
lock_key = _lock_key(self.job.pk)
cache.set(lock_key, "other-task-id", timeout=60)
mock_manager = self._setup_mock_nats(mock_manager_class)

# Create error result
error_data = self._create_error_result(image_id=str(self.images[0].pk))
reply_subject = "tasks.reply.test789"

# Task should raise retry exception when lock not acquired
# The task internally calls self.retry() which raises a Retry exception
from celery.exceptions import Retry

with self.assertRaises(Retry):
process_nats_pipeline_result.apply(
kwargs={
"job_id": self.job.pk,
"result_data": error_data,
"reply_subject": reply_subject,
}
)
# Worker 1 processes images[0]
result_1 = process_nats_pipeline_result.apply(
kwargs={
"job_id": self.job.pk,
"result_data": self._create_error_result(image_id=str(self.images[0].pk)),
"reply_subject": "reply.concurrent.1",
}
)

# Worker 2 processes images[1] — no retry, no lock to wait for
result_2 = process_nats_pipeline_result.apply(
kwargs={
"job_id": self.job.pk,
"result_data": self._create_error_result(image_id=str(self.images[1].pk)),
"reply_subject": "reply.concurrent.2",
}
)

self.assertTrue(result_1.successful())
self.assertTrue(result_2.successful())

# Assert: Progress was NOT updated (lock not acquired)
# Both images should be marked as processed
manager = AsyncJobStateManager(self.job.pk)
progress = manager.get_progress("process")
self.assertEqual(progress.processed, 0)
self.assertIsNotNone(progress)
self.assertEqual(progress.processed, 2)
self.assertEqual(progress.total, 3)
self.assertEqual(mock_manager.acknowledge_task.call_count, 2)

@patch("ami.jobs.tasks.TaskQueueManager")
def test_process_nats_pipeline_result_error_job_not_found(self, mock_manager_class):
183 changes: 84 additions & 99 deletions ami/ml/orchestration/async_job_state.py
@@ -2,8 +2,15 @@
Internal progress tracking for async (NATS) job processing, backed by Redis.

Multiple Celery workers process image batches concurrently and report progress
here using Redis for atomic updates with locking. This module is purely internal
— nothing outside the worker pipeline reads from it directly.
here using Redis native set operations. No locking is required because:

- SREM (remove processed images from pending set) is atomic per call
- SADD (add to failed set) is atomic per call
- SCARD (read set size) is O(1) without deserializing members

Workers update state independently via a single Redis pipeline round-trip.
This module is purely internal — nothing outside the worker pipeline reads
from it directly.

How this relates to the Job model (ami/jobs/models.py):

@@ -27,7 +34,7 @@
import logging
from dataclasses import dataclass

from django.core.cache import cache
from django_redis import get_redis_connection

logger = logging.getLogger(__name__)

@@ -50,17 +57,13 @@ class JobStateProgress:
failed: int = 0 # source images that returned an error from the processing service


def _lock_key(job_id: int) -> str:
return f"job:{job_id}:process_results_lock"


class AsyncJobStateManager:
"""
Manages real-time job progress in Redis for concurrent NATS workers.

Each job has per-stage pending image lists and a shared failed image set.
Workers acquire a Redis lock before mutating state, ensuring atomic updates
even when multiple Celery tasks process batches in parallel.
Each job has per-stage pending image sets and a shared failed image set,
all stored as native Redis sets. Workers update state via atomic SREM/SADD
commands — no locking needed.

The results are ephemeral — _update_job_progress() in ami/jobs/tasks.py
copies each snapshot into the persistent Job.progress JSONB field.
@@ -70,31 +73,32 @@ class AsyncJobStateManager:
STAGES = ["process", "results"]

def __init__(self, job_id: int):
"""
Initialize the task state manager for a specific job.

Args:
job_id: The job primary key
"""
self.job_id = job_id
self._pending_key = f"job:{job_id}:pending_images"
self._total_key = f"job:{job_id}:pending_images_total"
self._failed_key = f"job:{job_id}:failed_images"

def _get_redis(self):
return get_redis_connection("default")

def initialize_job(self, image_ids: list[str]) -> None:
"""
Initialize job tracking with a list of image IDs to process.

Args:
image_ids: List of image IDs that need to be processed
"""
for stage in self.STAGES:
cache.set(self._get_pending_key(stage), image_ids, timeout=self.TIMEOUT)

# Initialize failed images set for process stage only
cache.set(self._failed_key, set(), timeout=self.TIMEOUT)

cache.set(self._total_key, len(image_ids), timeout=self.TIMEOUT)
redis = self._get_redis()
with redis.pipeline() as pipe:
for stage in self.STAGES:
pending_key = self._get_pending_key(stage)
pipe.delete(pending_key)
if image_ids:
pipe.sadd(pending_key, *image_ids)
pipe.expire(pending_key, self.TIMEOUT)
pipe.delete(self._failed_key)
pipe.set(self._total_key, len(image_ids), ex=self.TIMEOUT)
pipe.execute()

def _get_pending_key(self, stage: str) -> str:
return f"{self._pending_key}:{stage}"
@@ -103,100 +107,81 @@ def update_state(
self,
processed_image_ids: set[str],
stage: str,
request_id: str,
failed_image_ids: set[str] | None = None,
) -> None | JobStateProgress:
) -> "JobStateProgress | None":
"""
Update the task state with newly processed images.
Atomically update job state with newly processed images.

Uses a Redis pipeline (single round-trip). SREM and SADD are each
individually atomic; the pipeline batches them with SCARD/GET to avoid
multiple round-trips. Workers can call this concurrently — no lock needed.

Args:
processed_image_ids: Set of image IDs that have just been processed
stage: The processing stage ("process" or "results")
request_id: Unique identifier for this processing request
detections_count: Number of detections to add to cumulative count
classifications_count: Number of classifications to add to cumulative count
captures_count: Number of captures to add to cumulative count
failed_image_ids: Set of image IDs that failed processing (optional)

Returns:
JobStateProgress snapshot, or None if Redis state is missing
(job expired or not yet initialized).
"""
# Create a unique lock key for this job
lock_key = _lock_key(self.job_id)
lock_timeout = 360 # 6 minutes (matches task time_limit)
lock_acquired = cache.add(lock_key, request_id, timeout=lock_timeout)
if not lock_acquired:
redis = self._get_redis()
pending_key = self._get_pending_key(stage)

with redis.pipeline() as pipe:
if processed_image_ids:
pipe.srem(pending_key, *processed_image_ids)
if failed_image_ids:
pipe.sadd(self._failed_key, *failed_image_ids)
pipe.scard(pending_key)
pipe.scard(self._failed_key)
pipe.get(self._total_key)
results = pipe.execute()

# Last 3 results are always scard(pending), scard(failed), get(total)
# regardless of whether SREM/SADD appear at the front.
remaining, failed_count, total_raw = results[-3], results[-2], results[-1]

if total_raw is None:
return None

try:
# Update progress tracking in Redis
progress_info = self._commit_update(processed_image_ids, stage, failed_image_ids)
return progress_info
finally:
# Always release the lock when done
current_lock_value = cache.get(lock_key)
# Only delete if we still own the lock (prevents race condition)
if current_lock_value == request_id:
cache.delete(lock_key)
logger.debug(f"Released lock for job {self.job_id}, task {request_id}")

def get_progress(self, stage: str) -> JobStateProgress | None:
"""Read-only progress snapshot for the given stage. Does not acquire a lock or mutate state."""
pending_images = cache.get(self._get_pending_key(stage))
total_images = cache.get(self._total_key)
if pending_images is None or total_images is None:
return None
remaining = len(pending_images)
processed = total_images - remaining
percentage = float(processed) / total_images if total_images > 0 else 1.0
failed_set = cache.get(self._failed_key) or set()
total = int(total_raw)
processed = total - remaining
percentage = float(processed) / total if total > 0 else 1.0

logger.info(
f"Pending images from Redis for job {self.job_id} {stage}: " f"{remaining}/{total}: {percentage*100}%"
)

return JobStateProgress(
remaining=remaining,
total=total_images,
total=total,
processed=processed,
percentage=percentage,
failed=len(failed_set),
failed=failed_count,
)

def _commit_update(
self,
processed_image_ids: set[str],
stage: str,
failed_image_ids: set[str] | None = None,
) -> JobStateProgress | None:
"""
Update pending images and return progress. Must be called under lock.
def get_progress(self, stage: str) -> "JobStateProgress | None":
"""Read-only progress snapshot for the given stage."""
redis = self._get_redis()
pending_key = self._get_pending_key(stage)

Removes processed_image_ids from the pending set and persists the update.
"""
pending_images = cache.get(self._get_pending_key(stage))
total_images = cache.get(self._total_key)
if pending_images is None or total_images is None:
return None
remaining_images = [img_id for img_id in pending_images if img_id not in processed_image_ids]
assert len(pending_images) >= len(remaining_images)
cache.set(self._get_pending_key(stage), remaining_images, timeout=self.TIMEOUT)

remaining = len(remaining_images)
processed = total_images - remaining
percentage = float(processed) / total_images if total_images > 0 else 1.0

# Update failed images set if provided
if failed_image_ids:
existing_failed = cache.get(self._failed_key) or set()
updated_failed = existing_failed | failed_image_ids # Union to prevent duplicates
cache.set(self._failed_key, updated_failed, timeout=self.TIMEOUT)
failed_set = updated_failed
else:
failed_set = cache.get(self._failed_key) or set()
with redis.pipeline() as pipe:
pipe.scard(pending_key)
pipe.scard(self._failed_key)
pipe.get(self._total_key)
remaining, failed_count, total_raw = pipe.execute()

failed_count = len(failed_set)
if total_raw is None:
return None

logger.info(
f"Pending images from Redis for job {self.job_id} {stage}: "
f"{remaining}/{total_images}: {percentage*100}%"
)
total = int(total_raw)
processed = total - remaining
percentage = float(processed) / total if total > 0 else 1.0

return JobStateProgress(
remaining=remaining,
total=total_images,
total=total,
processed=processed,
percentage=percentage,
failed=failed_count,
@@ -206,7 +191,7 @@ def cleanup(self) -> None:
"""
Delete all Redis keys associated with this job.
"""
for stage in self.STAGES:
cache.delete(self._get_pending_key(stage))
cache.delete(self._failed_key)
cache.delete(self._total_key)
redis = self._get_redis()
keys = [self._get_pending_key(stage) for stage in self.STAGES]
keys += [self._failed_key, self._total_key]
redis.delete(*keys)
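To make the lock-free pattern adopted above concrete outside the Django cache layer, here is a minimal standalone sketch using redis-py directly. The key names and progress math mirror the diff, but the connection settings and image IDs are made up; this is an illustration, not the project's code.

```python
# Standalone sketch of the atomic SADD/SREM/SCARD pattern; assumes a local Redis.
import redis

r = redis.Redis(host="localhost", port=6379, decode_responses=True)

job_id = 123
pending_key = f"job:{job_id}:pending_images:process"
total_key = f"job:{job_id}:pending_images_total"

# initialize_job: seed the pending set and record the total in one round-trip.
image_ids = ["img-1", "img-2", "img-3"]
with r.pipeline() as pipe:
    pipe.delete(pending_key)
    pipe.sadd(pending_key, *image_ids)
    pipe.set(total_key, len(image_ids))
    pipe.execute()


# update_state: two workers can call this concurrently. SREM is atomic per
# call, so neither update is lost and no lock is needed.
def report_processed(ids: list[str]) -> float:
    with r.pipeline() as pipe:
        pipe.srem(pending_key, *ids)  # drop processed images from the pending set
        pipe.scard(pending_key)       # O(1) count of images still pending
        pipe.get(total_key)
        _, remaining, total = pipe.execute()
    total = int(total)
    return (total - remaining) / total if total else 1.0


print(report_processed(["img-1"]))           # ~0.33
print(report_processed(["img-2", "img-3"]))  # 1.0
```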
2 changes: 1 addition & 1 deletion ami/ml/orchestration/nats_queue.py
@@ -64,7 +64,7 @@ class TaskQueueManager:
def __init__(self, nats_url: str | None = None, max_ack_pending: int | None = None):
self.nats_url = nats_url or getattr(settings, "NATS_URL", "nats://nats:4222")
self.max_ack_pending = (
max_ack_pending if max_ack_pending is not None else getattr(settings, "NATS_MAX_ACK_PENDING", 100)
max_ack_pending if max_ack_pending is not None else getattr(settings, "NATS_MAX_ACK_PENDING", 1000)
)
self.nc: nats.NATS | None = None
self.js: JetStreamContext | None = None
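For context on what this setting controls, a hedged sketch of how a max_ack_pending value typically reaches a JetStream consumer with nats-py. The stream, subject, and durable names below are made up, and the actual wiring inside this repo's TaskQueueManager may differ.

```python
# Hedged sketch, not from this repo. Assumes a local NATS server with a
# JetStream stream that already covers the "tasks.example" subject.
import asyncio

import nats
from nats.errors import TimeoutError as NatsTimeoutError
from nats.js.api import ConsumerConfig


async def main() -> None:
    nc = await nats.connect("nats://localhost:4222")
    js = nc.jetstream()

    # max_ack_pending caps how many delivered-but-unacknowledged messages the
    # consumer may have in flight; raising it (e.g. 100 -> 1000) lets more
    # worker tasks run concurrently before NATS pauses delivery.
    sub = await js.pull_subscribe(
        "tasks.example",
        durable="example-worker",
        config=ConsumerConfig(max_ack_pending=1000),
    )

    try:
        msgs = await sub.fetch(10, timeout=5)
    except NatsTimeoutError:
        msgs = []

    for msg in msgs:
        await msg.ack()

    await nc.close()


asyncio.run(main())
```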