102 changes: 83 additions & 19 deletions ami/jobs/tasks.py
@@ -176,9 +176,12 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub
classifications_count = sum(len(detection.classifications) for detection in pipeline_result.detections)
captures_count = len(pipeline_result.source_images)

acked = _ack_task_via_nats(reply_subject, job.logger)
# Update job stage with calculated progress

# Do NOT ack NATS yet. ACK must happen AFTER the results-stage SREM and
# _update_job_progress so that a worker crash between save_results and
# progress commit leaves the message redeliverable. Previously the ACK
# ran here (before SREM): on crash, NATS drained permanently while
# Redis pending_images:results kept the id, stranding the job at
# partial progress with no path to completion. See antenna#1232.
try:
progress_info = state_manager.update_state(
processed_image_ids,
@@ -187,19 +190,20 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub
except RedisError as e:
# Transient. save_results dedupes on re-run (get_or_create_detection)
# and SREM is a no-op on already-removed ids, so a Celery retry is
# safe for the DB and Redis sets. The caveat is _update_job_progress
# accumulates detections/classifications/captures on the results
# stage (see _update_job_progress stage=="results" branch); if this
# retry runs a second time (or NATS redelivers to ADC because
# ack_wait elapsed before we got here), those counters will inflate
# cosmetically. Tracked in #1232.
# safe for the DB and Redis sets. Counter accumulation is gated on
# progress_info.newly_removed below, so replays will not inflate
# detections/classifications/captures (fixes antenna#1232 replay case).
job.logger.warning(
f"Transient Redis error updating job {job_id} state (stage=results); Celery will retry: {e}",
exc_info=True,
)
raise

if not progress_info:
# State keys genuinely missing (total-images key returned None). Ack
# first so NATS stops redelivering a message whose state is gone,
# then fail the job. Mirrors the stage=process missing-state path.
_ack_task_via_nats(reply_subject, job.logger)
_fail_job(job_id, "Job state keys not found in Redis (likely cleaned up concurrently)")
return

@@ -208,15 +212,40 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub
if progress_info.total > 0 and (progress_info.failed / progress_info.total) > FAILURE_THRESHOLD:
complete_state = JobState.FAILURE

_update_job_progress(
job_id,
"results",
progress_info.percentage,
complete_state=complete_state,
detections=detections_count,
classifications=classifications_count,
captures=captures_count,
)
# Counter-inflation guard: only add detection/classification/capture counts
# when SREM actually removed ids (first processing of this result). On a
# replay (NATS redelivered the message or the Celery task retried past
# the SREM), newly_removed==0 and we skip accumulation to keep the
# counters idempotent. The percentage/status path still runs because
# _update_job_progress uses max() and preserves FAILURE regardless.
first_delivery = progress_info.newly_removed > 0
_update_job_progress(
job_id,
"results",
progress_info.percentage,
complete_state=complete_state,
detections=detections_count if first_delivery else 0,
classifications=classifications_count if first_delivery else 0,
captures=captures_count if first_delivery else 0,
)

# Ack LAST — only after the results-stage SREM and progress commit are
# durable. If anything above crashes, NATS will redeliver the message
# and the full result path re-runs idempotently: save_results dedupes
# on (detection, source_image), SREM is a no-op on already-removed ids
# (newly_removed==0 gates counter accumulation), and the progress
# percentage is clamped by max() to never regress.
acked = _ack_task_via_nats(reply_subject, job.logger)

except RedisError:
# Logged above at the specific update_state call site; re-raise so
@@ -337,8 +366,11 @@ def _update_job_progress(
# Don't overwrite a stage with a stale progress value.
# This guards against the race where a slower worker calls _update_job_progress
# after a faster worker has already marked further progress.
passed_progress = progress_percentage
existing_progress: float | None = None
try:
existing_stage = job.progress.get_stage(stage)
existing_progress = existing_stage.progress
progress_percentage = max(existing_stage.progress, progress_percentage)
# Explicitly preserve FAILURE: once a stage is marked FAILURE it should
# never regress to a non-failure state, regardless of enum ordering.
@@ -347,6 +379,17 @@
except (ValueError, AttributeError):
pass # Stage doesn't exist yet; proceed normally

# Diagnostic: when max() lifts the percentage to 1.0 from a partial value
# this worker computed, surface it. A legitimate jump means another
# worker concurrently completed the stage; an unexpected jump (e.g. the
# premature-cleanup pattern from antenna#????) is otherwise invisible.
if existing_progress is not None and progress_percentage >= 1.0 and passed_progress < 1.0:
job.logger.warning(
f"Stage '{stage}' progress lifted to 100% by max() guard: "
f"this worker passed {passed_progress*100:.1f}%, DB had {existing_progress*100:.1f}%. "
f"If no other worker just legitimately finished this stage, this is a state-race symptom."
)

# Determine the status to write:
# - Stage complete (100%): use complete_state (SUCCESS or FAILURE)
# - Stage incomplete but FAILURE already determined: keep FAILURE visible
@@ -374,6 +417,11 @@
# Clean up async resources for completed jobs that use NATS/Redis
if job.progress.is_complete():
job = Job.objects.get(pk=job_id) # Re-fetch outside transaction
# Diagnostic: log which stages satisfied the complete condition. Without
# this, premature-cleanup bugs (cleanup fires while results are still
# mid-flight) are hard to trace back to a specific stage transition.
stages_summary = ", ".join(f"{s.key}={s.progress*100:.1f}% {s.status}" for s in job.progress.stages)
job.logger.info(f"is_complete()=True after stage='{stage}' update; firing cleanup. Stages: {stages_summary}")
cleanup_async_job_if_needed(job)


@@ -659,9 +707,25 @@ def update_job_status(sender, task_id, task, state: str, retval=None, **kwargs):

@task_failure.connect(sender=run_job, retry=False)
def update_job_failure(sender, task_id, exception, *args, **kwargs):
from ami.jobs.models import Job, JobState
from ami.jobs.models import Job, JobDispatchMode, JobState

job = Job.objects.get(task_id=task_id)

# For ASYNC_API jobs where images have been queued to NATS but the final
# stages have not completed, a run_job failure (e.g. a transient exception
# raised *after* queue_images_to_nats returned) would collapse an otherwise-
# healthy async job: NATS workers are still processing, results
# are still arriving, but this handler would mark FAILURE and cleanup would
# destroy the stream/consumer + Redis state mid-flight. Defer terminal
# state to the async result handler, which owns is_complete() transitions.
# Mirrors the SUCCESS guard in update_job_status (task_postrun).
if job.dispatch_mode == JobDispatchMode.ASYNC_API and not job.progress.is_complete():
job.logger.warning(
f'Job #{job.pk} "{job.name}" run_job raised but async processing is in-flight; '
f"deferring FAILURE to async progress handler: {exception}"
)
return

job.update_status(JobState.FAILURE, save=False)

job.logger.error(f'Job #{job.pk} "{job.name}" failed: {exception}')
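The crash-safety reasoning threaded through the comments above reduces to one ordering invariant: make every effect durable and idempotent before acknowledging delivery. A minimal sketch of that invariant, with hypothetical helper stubs standing in for the real calls in the diff (this is not the actual handler):

```python
from dataclasses import dataclass

@dataclass
class Progress:
    percentage: float
    newly_removed: int  # SREM's integer return; 0 when this is a replay

# Hypothetical stand-ins for the calls in process_nats_pipeline_result.
def save_results(result: dict) -> None: ...                     # dedupes on re-run
def update_state(job_id: int, ids: list[str]) -> Progress: ...  # pipelined SREM
def commit_progress(job_id: int, pct: float, counts: dict | None) -> None: ...
def ack(reply_subject: str) -> None: ...                        # NATS ACK

def handle_result(job_id: int, ids: list[str], result: dict, reply_subject: str) -> None:
    save_results(result)                  # 1. durable + idempotent writes
    progress = update_state(job_id, ids)  # 2. SREM commits "these ids are done"
    # 3. Counters accumulate only on first delivery (newly_removed > 0).
    counts = result.get("counts") if progress.newly_removed > 0 else None
    commit_progress(job_id, progress.percentage, counts)  # 4. max()-clamped
    # 5. ACK last: a crash anywhere above leaves the message unacked, NATS
    #    redelivers after ack_wait, and steps 1-4 replay without side effects.
    ack(reply_subject)
```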
132 changes: 132 additions & 0 deletions ami/jobs/tests/test_tasks.py
@@ -351,6 +351,138 @@ def test_genuinely_missing_state_acks_and_fails_job(self, mock_manager_class, mo
args, _ = mock_fail.call_args
self.assertIn("Job state keys not found in Redis", args[1])

@patch("ami.jobs.tasks._ack_task_via_nats")
@patch("ami.jobs.tasks.TaskQueueManager")
def test_ack_deferred_until_after_results_stage_srem(self, mock_manager_class, mock_ack):
"""
Bug A regression: NATS ACK must NOT happen until after the results-stage
SREM is durable in Redis. A worker crash between save_results and the
results SREM would otherwise strand the image in pending_images:results
with NATS already drained (no redelivery) — the job's results stage
never reaches 100% and no code path reconciles it.

This test simulates a crash on the results-stage SREM. Correct behavior:
- process-stage SREM succeeded (called first, no crash)
- save_results ran
- results-stage SREM raised RedisError → exception propagates to Celery
- ACK was NOT called (so NATS will redeliver after ack_wait)

On buggy code (ACK before results SREM), mock_ack would be called before
the raise, leaving the id stranded in Redis.
"""
from redis.exceptions import RedisError

self._setup_mock_nats(mock_manager_class)

# save_results requires the pipeline to have at least one detection
# algorithm. Attach a minimal one so we exercise the full save_results
# path before hitting the results-stage SREM we're testing.
detection_algorithm = Algorithm.objects.create(
name="ack-ordering-detector",
key="ack-ordering-detector",
task_type=AlgorithmTaskType.LOCALIZATION,
)
self.pipeline.algorithms.add(detection_algorithm)

# Use a success result (not an error) so save_results path runs fully.
# An empty detections list keeps save_results cheap.
success_data = PipelineResultsResponse(
pipeline="test-pipeline",
algorithms={},
total_time=1.0,
source_images=[SourceImageResponse(id=str(self.images[0].pk), url="http://example.com/test_image_0.jpg")],
detections=[],
errors=None,
).dict()

real_update_state = AsyncJobStateManager.update_state

def fail_on_results_stage(self, processed_image_ids, stage, failed_image_ids=None):
if stage == "results":
raise RedisError("connection reset on results SREM")
return real_update_state(self, processed_image_ids, stage, failed_image_ids)

with patch.object(AsyncJobStateManager, "update_state", fail_on_results_stage):
with self.assertRaises(RedisError):
process_nats_pipeline_result(
job_id=self.job.pk,
result_data=success_data,
reply_subject="reply.ack-ordering",
)

mock_ack.assert_not_called()

# Process stage SREM ran and removed the id; results stage still holds it,
# waiting for a successful retry or NATS redelivery.
process_progress = AsyncJobStateManager(self.job.pk).get_progress("process")
results_progress = AsyncJobStateManager(self.job.pk).get_progress("results")
self.assertEqual(process_progress.processed, 1)
self.assertEqual(results_progress.processed, 0)

@patch("ami.jobs.tasks.TaskQueueManager")
def test_results_counter_does_not_inflate_on_replay(self, mock_manager_class):
"""
Bug A companion (antenna#1232): _update_job_progress("results") accumulates
detections/classifications/captures by reading existing values and adding
new ones — not idempotent. On a NATS redelivery or Celery retry, the same
batch can legitimately arrive twice. The fix gates accumulation on
update_state's newly_removed (SREM's integer return, 0 on replay).

Scenario: deliver the same result twice. Counters should reflect one
batch, not two.
"""
self._setup_mock_nats(mock_manager_class)

detection_algorithm = Algorithm.objects.create(
name="replay-detector",
key="replay-detector",
task_type=AlgorithmTaskType.LOCALIZATION,
)
self.pipeline.algorithms.add(detection_algorithm)

# Empty-detections success keeps save_results cheap; the counter
# accumulation still runs because captures_count = len(source_images) = 1.
success_data = PipelineResultsResponse(
pipeline="test-pipeline",
algorithms={},
total_time=1.0,
source_images=[SourceImageResponse(id=str(self.images[0].pk), url="http://example.com/test_image_0.jpg")],
detections=[],
errors=None,
).dict()

# First delivery: counters should advance by 1 capture.
process_nats_pipeline_result.apply(
kwargs={"job_id": self.job.pk, "result_data": success_data, "reply_subject": "reply.first"}
)

self.job.refresh_from_db()
results_stage = next(s for s in self.job.progress.stages if s.key == "results")
captures_after_first = next(
(p.value for p in results_stage.params if p.key == "captures"),
0,
)
self.assertEqual(captures_after_first, 1, "first delivery should count 1 capture")

# Second delivery of the same result (NATS redeliver / Celery retry after
# the results SREM was already durable). SREM now returns 0 (id already
# gone). Counters must NOT double.
process_nats_pipeline_result.apply(
kwargs={"job_id": self.job.pk, "result_data": success_data, "reply_subject": "reply.replay"}
)

self.job.refresh_from_db()
results_stage = next(s for s in self.job.progress.stages if s.key == "results")
captures_after_replay = next(
(p.value for p in results_stage.params if p.key == "captures"),
0,
)
self.assertEqual(
captures_after_replay,
1,
f"replay must not inflate captures counter (got {captures_after_replay}, expected 1)",
)
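
A note on the two invocation styles used in the tests above: the ack-ordering test calls the task function directly so that assertRaises can observe the RedisError, while the replay test goes through .apply(), which runs the task eagerly inside Celery's request machinery. A sketch of the difference, reusing the fixtures from the tests above and assuming Celery's defaults (where task_eager_propagates is off):

```python
# Direct call: the task behaves like a plain function, so exceptions
# propagate to the caller -- which is what assertRaises needs.
with self.assertRaises(RedisError):
    process_nats_pipeline_result(job_id=1, result_data=data, reply_subject="r")

# .apply(): executes eagerly but wraps the outcome in an EagerResult; a raised
# exception is captured on the result instead of propagating (unless
# throw=True is passed or task_eager_propagates is enabled).
result = process_nats_pipeline_result.apply(
    kwargs={"job_id": 1, "result_data": data, "reply_subject": "r"}
)
assert result.successful() or result.failed()
```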

@patch("ami.jobs.tasks.TaskQueueManager")
def test_process_nats_pipeline_result_error_job_not_found(self, mock_manager_class):
"""
7 changes: 7 additions & 0 deletions ami/ml/orchestration/async_job_state.py
@@ -56,6 +56,7 @@ class JobStateProgress:
processed: int = 0 # source images completed (success + failed)
percentage: float = 0.0 # processed / total
failed: int = 0 # source images that returned an error from the processing service
newly_removed: int = 0 # number of IDs actually removed by this SREM call (0 on replay)


class AsyncJobStateManager:
@@ -156,6 +157,11 @@ def update_state(
# regardless of whether SREM/SADD appear at the front.
remaining, failed_count, total_raw = results[-3], results[-2], results[-1]

# SREM's integer return (number of members actually removed) is at results[0]
# when processed_image_ids is non-empty. Zero on a replay because the IDs are
# no longer in the set. Used by callers to gate idempotent counter accumulation.
newly_removed = results[0] if processed_image_ids else 0

if total_raw is None:
return None

@@ -173,6 +179,7 @@
processed=processed,
percentage=percentage,
failed=failed_count,
newly_removed=newly_removed,
)
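
The contract newly_removed leans on is Redis itself: SREM returns the number of members it actually removed, so a replay against an already-emptied set reports 0. A standalone illustration of both that return value and the positional pipeline parsing used in update_state (the key name here is hypothetical; the real layout lives in AsyncJobStateManager):

```python
import redis

r = redis.Redis()
key = "pending_images:results:42"  # hypothetical key name for illustration
r.sadd(key, "101", "102")

# update_state batches its Redis ops in one pipeline; replies come back as a
# positional list, which is why the code reads results[0] and results[-3:].
pipe = r.pipeline()
pipe.srem(key, "101", "102")  # -> results[0]: count of ids actually removed
pipe.scard(key)               # -> e.g. remaining count
results = pipe.execute()
assert results == [2, 0]      # first delivery: 2 removed, 0 remaining

# Replay (NATS redelivery / Celery retry past the SREM): the ids are gone,
# so SREM reports 0 -- the signal callers use to skip counter accumulation.
assert r.srem(key, "101", "102") == 0
```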

def get_progress(self, stage: str) -> "JobStateProgress | None":
10 changes: 9 additions & 1 deletion ami/ml/orchestration/nats_queue.py
@@ -44,6 +44,14 @@ async def get_connection(nats_url: str) -> tuple[nats.NATS, JetStreamContext]:


TASK_TTR = getattr(settings, "NATS_TASK_TTR", 30) # Visibility timeout in seconds (configurable)

# Max delivery attempts per NATS message (1 original + N-1 retries).
# A processing service that consistently fails (e.g. returns results referencing
# an algorithm that the pipeline doesn't declare) will burn ADC + worker time on
# every retry; one retry covers a transient blip and is the right tradeoff.
# Override via the NATS_MAX_DELIVER setting when per-environment tuning is needed.
NATS_MAX_DELIVER = getattr(settings, "NATS_MAX_DELIVER", 2)

ADVISORY_STREAM_NAME = "advisories" # Shared stream for max delivery advisories across all jobs


@@ -342,7 +350,7 @@ async def _ensure_consumer(self, job_id: int):
durable_name=consumer_name,
ack_policy=AckPolicy.EXPLICIT,
ack_wait=TASK_TTR, # Visibility timeout (TTR)
max_deliver=5, # Max retry attempts
max_deliver=NATS_MAX_DELIVER,
deliver_policy=DeliverPolicy.ALL,
max_ack_pending=self.max_ack_pending,
filter_subject=subject,
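With max_deliver lowered from 5 to 2, a poisoned message exhausts its attempts quickly, which makes JetStream's max-deliveries advisory the main signal that a task has been given up on (and presumably what the shared advisories stream above captures). A sketch of watching those advisories with nats-py; the subject layout is NATS's documented advisory subject, and the broad wildcard subscription is an assumption, not this codebase's actual listener:

```python
import asyncio
import nats

async def watch_max_deliveries(nats_url: str) -> None:
    nc = await nats.connect(nats_url)

    async def on_advisory(msg) -> None:
        # The JSON payload names the stream, consumer, and stream_seq of the
        # message that hit max_deliver without an ack -- the hook for routing
        # it to a failure/reconciliation path.
        print(msg.subject, msg.data.decode())

    # Subject: $JS.EVENT.ADVISORY.CONSUMER.MAX_DELIVERIES.<stream>.<consumer>
    await nc.subscribe("$JS.EVENT.ADVISORY.CONSUMER.MAX_DELIVERIES.>", cb=on_advisory)
    await asyncio.sleep(3600)  # keep the demo subscriber alive

asyncio.run(watch_max_deliveries("nats://localhost:4222"))
```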