Skip to content
Closed
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
8df89be
Avoid redis based locking by using atomic updates
carlosgjs Feb 24, 2026
1096fd9
Merge branch 'main' into carlosg/redisatomic
carlosgjs Feb 24, 2026
30c8db3
Test concurrency
carlosgjs Feb 25, 2026
deea095
Increase max ack pending
carlosgjs Feb 25, 2026
20c0fbd
update comment
carlosgjs Feb 25, 2026
e84421e
CR feedback
carlosgjs Feb 25, 2026
cbb2d7f
Cancel jobs if Redis state is missing
carlosgjs Feb 25, 2026
3861190
Add chaos monkey
carlosgjs Feb 25, 2026
d591bd6
CR feedback
carlosgjs Feb 25, 2026
4720bb6
CR 2
carlosgjs Feb 26, 2026
f0cd403
fix: OrderedEnum comparisons now override str MRO in subclasses
mihow Feb 26, 2026
e3134a1
fix: correct misleading error log about NATS redelivery
mihow Feb 26, 2026
41b1232
Merge branch 'carlosg/redisatomic' of github.com:uw-ssec/antenna into…
carlosgjs Feb 26, 2026
94e1bbb
Use job.logger
carlosgjs Feb 26, 2026
dcf57fe
Use job.logger
carlosgjs Feb 26, 2026
4a25e54
Integrate cancellation support
carlosgjs Feb 26, 2026
654593b
Merge branch 'carlosg/redisatomic' into carlos/redisfail
carlosgjs Feb 26, 2026
5d38d67
merge, update tests
carlosgjs Feb 26, 2026
ac90c2f
Remove pause support in monkey
carlosgjs Feb 26, 2026
4eb763a
fix: cancel async jobs by cleaning up NATS/Redis and stopping task de…
mihow Feb 27, 2026
8671214
fix(ui): hide Retry button while job is in CANCELING state
mihow Feb 27, 2026
b1146cc
fix: downgrade Redis-missing log to warning for canceled jobs
mihow Feb 27, 2026
dccaceb
docs: add async job monitoring reference
mihow Feb 27, 2026
d63be48
fix: update tests for active_states() guard on /tasks endpoint
mihow Feb 27, 2026
934db1d
Merge branch 'main' into carlos/redisfail
carlosgjs Feb 27, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .agents/AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,16 @@ images = SourceImage.objects.annotate(det_count=Count('detections'))
- Use `@shared_task` decorator for all tasks
- Check Flower UI for debugging: http://localhost:5555

### E2E Testing & Monitoring Async Jobs

Run an end-to-end ML job test:
```bash
docker compose run --rm django python manage.py test_ml_job_e2e \
--project 18 --dispatch-mode async_api --collection 142 --pipeline "global_moths_2024"
```

For monitoring running jobs (Django ORM, REST API, NATS consumer state, Redis counters, worker logs, etc.), see `docs/claude/reference/monitoring-async-jobs.md`.

### Running a Single Test

```bash
Expand Down
96 changes: 96 additions & 0 deletions ami/jobs/management/commands/chaos_monkey.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""
Fault injection utility for manual chaos testing of ML async jobs.

Use alongside `test_ml_job_e2e` to verify job behaviour when Redis or NATS
becomes unavailable or loses state mid-processing.

Usage examples:

# Flush all Redis state immediately (simulates FLUSHDB mid-job)
python manage.py chaos_monkey flush redis

# Flush all NATS JetStream streams (simulates broker state loss)
python manage.py chaos_monkey flush nats
"""

from asgiref.sync import async_to_sync
from django.core.management.base import BaseCommand, CommandError
from django_redis import get_redis_connection

# NATS server URL for the local docker-compose environment — presumably the
# `ami_local_nats` service from compose; hard-coded, not read from
# settings/env (NOTE(review): confirm this shouldn't come from Django settings).
NATS_URL = "nats://ami_local_nats:4222"


class Command(BaseCommand):
    """Management command that injects faults into Redis or NATS for chaos testing."""

    help = "Inject faults into Redis or NATS for chaos/resilience testing"

    def add_arguments(self, parser):
        # Positional args: what to do, and which backing service to do it to.
        parser.add_argument(
            "action",
            choices=["flush"],
            help="flush: wipe all state.",
        )
        parser.add_argument(
            "service",
            choices=["redis", "nats"],
            help="Target service to fault.",
        )

    def handle(self, *args, **options):
        # Dispatch on the (action, service) pair; argparse `choices` guarantees
        # only valid combinations reach this point.
        dispatch = {
            ("flush", "redis"): self._flush_redis,
            ("flush", "nats"): self._flush_nats,
        }
        fault = dispatch.get((options["action"], options["service"]))
        if fault is not None:
            fault()

    # ------------------------------------------------------------------
    # Redis
    # ------------------------------------------------------------------

    def _flush_redis(self):
        """Wipe the default Redis database with FLUSHDB."""
        self.stdout.write("Flushing Redis database (FLUSHDB)...")
        try:
            get_redis_connection("default").flushdb()
        except Exception as e:
            raise CommandError(f"Failed to flush Redis: {e}") from e
        self.stdout.write(self.style.SUCCESS("Redis flushed."))

    # ------------------------------------------------------------------
    # NATS
    # ------------------------------------------------------------------

    def _flush_nats(self):
        """Delete all JetStream streams via the NATS Python client."""
        self.stdout.write("Flushing all NATS JetStream streams...")

        async def _wipe_streams():
            # Local import: only needed when this code path actually runs.
            import nats

            client = await nats.connect(NATS_URL, connect_timeout=5, allow_reconnect=False)
            try:
                jetstream = client.jetstream()
                infos = await jetstream.streams_info()
                removed = []
                for info in infos or []:
                    stream_name = info.config.name
                    await jetstream.delete_stream(stream_name)
                    removed.append(stream_name)
                return removed
            finally:
                # Always release the connection, even if a deletion fails.
                await client.close()

        try:
            removed = async_to_sync(_wipe_streams)()
        except Exception as e:
            raise CommandError(f"Failed to flush NATS: {e}") from e

        if not removed:
            self.stdout.write("No streams found — NATS already empty.")
            return
        for stream_name in removed:
            self.stdout.write(f"  Deleted stream: {stream_name}")
        self.stdout.write(self.style.SUCCESS(f"Deleted {len(removed)} stream(s)."))
6 changes: 5 additions & 1 deletion ami/jobs/management/commands/test_ml_job_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@


class Command(BaseCommand):
help = "Run end-to-end test of ML job processing"
help = (
"Run end-to-end test of ML job processing.\n\n"
"For monitoring and debugging running jobs, see:\n"
" docs/claude/reference/monitoring-async-jobs.md"
)

def add_arguments(self, parser):
parser.add_argument("--project", type=int, required=True, help="Project ID")
Expand Down
35 changes: 28 additions & 7 deletions ami/jobs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from ami.base.models import BaseModel
from ami.base.schemas import ConfigurableStage, ConfigurableStageParam
from ami.jobs.tasks import run_job
from ami.jobs.tasks import cleanup_async_job_if_needed, run_job
from ami.main.models import Deployment, Project, SourceImage, SourceImageCollection
from ami.ml.models import Pipeline
from ami.ml.post_processing.registry import get_postprocessing_task
Expand Down Expand Up @@ -88,6 +88,11 @@ def final_states(cls):
def failed_states(cls):
return [cls.FAILURE, cls.REVOKED, cls.UNKNOWN]

@classmethod
def active_states(cls):
"""States where a job is actively processing and should serve tasks to workers."""
return [cls.STARTED, cls.RETRY]


def get_status_label(status: JobState, progress: float) -> str:
"""
Expand Down Expand Up @@ -331,7 +336,11 @@ def emit(self, record: logging.LogRecord):
# Log to the current app logger
logger.log(record.levelno, self.format(record))

# Write to the logs field on the job instance
# Write to the logs field on the job instance.
# Refresh from DB first to reduce the window for concurrent overwrites — each
# worker holds its own stale in-memory copy of `logs`, so without a refresh the
# last writer always wins and earlier entries are silently dropped.
self.job.refresh_from_db(fields=["logs"])
timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
msg = f"[{timestamp}] {record.levelname} {self.format(record)}"
if msg not in self.job.logs.stdout:
Expand All @@ -350,7 +359,6 @@ def emit(self, record: logging.LogRecord):
self.job.save(update_fields=["logs"], update_progress=False)
except Exception as e:
logger.error(f"Failed to save logs for job #{self.job.pk}: {e}")
pass


@dataclass
Expand Down Expand Up @@ -966,15 +974,24 @@ def retry(self, async_task=True):

def cancel(self):
"""
Terminate the celery task.
Cancel a job. For async_api jobs, clean up NATS/Redis resources
and transition through CANCELING → REVOKED. For other jobs,
revoke the Celery task.
"""
self.status = JobState.CANCELING
self.save()

cleanup_async_job_if_needed(self)
if self.task_id:
task = run_job.AsyncResult(self.task_id)
if task:
task.revoke(terminate=True)
self.save()
if self.dispatch_mode == JobDispatchMode.ASYNC_API:
# For async jobs we need to set the status to revoked here since the task already
# finished (it only queues the images).
self.status = JobState.REVOKED
self.save()
else:
self.status = JobState.REVOKED
self.save()
Expand Down Expand Up @@ -1084,11 +1101,15 @@ def get_default_progress(cls) -> JobProgress:
def logger(self) -> logging.Logger:
_logger = logging.getLogger(f"ami.jobs.{self.pk}")

# Only add JobLogHandler if not already present
if not any(isinstance(h, JobLogHandler) for h in _logger.handlers):
# Also log output to a field on thie model instance
# Update or add JobLogHandler, always pointing to the current instance.
# The logger is a process-level singleton so its handler may reference a stale
# job instance from a previous task execution in this worker process.
handler = next((h for h in _logger.handlers if isinstance(h, JobLogHandler)), None)
if handler is None:
logger.info("Adding JobLogHandler to logger for job %s", self.pk)
_logger.addHandler(JobLogHandler(self))
else:
handler.job = self
_logger.propagate = False
return _logger

Expand Down
87 changes: 64 additions & 23 deletions ami/jobs/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,12 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub

state_manager = AsyncJobStateManager(job_id)

progress_info = state_manager.update_state(
processed_image_ids, stage="process", request_id=self.request.id, failed_image_ids=failed_image_ids
)
progress_info = state_manager.update_state(processed_image_ids, stage="process", failed_image_ids=failed_image_ids)
if not progress_info:
logger.warning(
f"Another task is already processing results for job {job_id}. "
f"Retrying task {self.request.id} in 5 seconds..."
)
raise self.retry(countdown=5, max_retries=10)
logger.warning(f"Redis state missing for job {job_id} — job may have been cleaned up prematurely.")
# Acknowledge the task to prevent retries, since we don't know the state
_ack_task_via_nats(reply_subject, logger)
return

try:
complete_state = JobState.SUCCESS
Expand Down Expand Up @@ -126,6 +123,7 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub
_ack_task_via_nats(reply_subject, logger)
return

acked = False
try:
# Save to database (this is the slow operation)
detections_count, classifications_count, captures_count = 0, 0, 0
Expand All @@ -145,20 +143,17 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub
captures_count = len(pipeline_result.source_images)

_ack_task_via_nats(reply_subject, job.logger)
acked = True
# Update job stage with calculated progress

progress_info = state_manager.update_state(
processed_image_ids,
stage="results",
request_id=self.request.id,
)

if not progress_info:
logger.warning(
f"Another task is already processing results for job {job_id}. "
f"Retrying task {self.request.id} in 5 seconds..."
)
raise self.retry(countdown=5, max_retries=10)
job.logger.warning(f"Redis state missing for job {job_id} — job may have been cleaned up prematurely.")
return

# update complete state based on latest progress info after saving results
complete_state = JobState.SUCCESS
Expand All @@ -176,9 +171,31 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub
)

except Exception as e:
job.logger.error(
f"Failed to process pipeline result for job {job_id}: {e}. NATS will redeliver the task message."
)
error = f"Error processing pipeline result for job {job_id}: {e}"
if not acked:
error += ". NATS will re-deliver the task message."

job.logger.error(error)


def _fail_job(job_id: int, reason: str) -> None:
    """Mark a job as FAILURE (unless already canceling/finished) and tear down its async resources.

    Args:
        job_id: Primary key of the job to fail.
        reason: Human-readable explanation recorded in the job's log.
    """
    # Imported here to avoid circular imports between tasks and models/orchestration.
    from ami.jobs.models import Job, JobState
    from ami.ml.orchestration.jobs import cleanup_async_job_resources

    try:
        with transaction.atomic():
            # Row lock so concurrent workers cannot race on the status transition.
            job = Job.objects.select_for_update().get(pk=job_id)
            if job.status in (JobState.CANCELING, *JobState.final_states()):
                # Already being canceled or already finished — nothing to do.
                return
            job.status = JobState.FAILURE
            # NOTE(review): naive datetime — if the project runs with USE_TZ=True this
            # should be django.utils.timezone.now(); confirm against settings.
            job.finished_at = datetime.datetime.now()
            job.save(update_fields=["status", "finished_at"])
    except Job.DoesNotExist:
        logger.error(f"Cannot fail job {job_id}: not found")
        # Still attempt cleanup so orphaned NATS/Redis state does not linger.
        cleanup_async_job_resources(job_id, logger)
        return

    # Log and clean up *after* the transaction commits: both touch external systems
    # (job.logger persists back to the DB; cleanup talks to NATS/Redis). Doing that
    # while holding the select_for_update row lock — or inside a transaction that
    # could still roll back after the side effects already ran — was the defect here.
    job.logger.error(f"Job {job_id} marked as FAILURE: {reason}")
    cleanup_async_job_resources(job.pk, job.logger)


def _ack_task_via_nats(reply_subject: str, job_logger: logging.Logger) -> None:
Expand Down Expand Up @@ -256,9 +273,33 @@ def _update_job_progress(
state_params["classifications"] = current_classifications + new_classifications
state_params["captures"] = current_captures + new_captures

# Don't overwrite a stage with a stale progress value.
# This guards against the race where a slower worker calls _update_job_progress
# after a faster worker has already marked further progress.
try:
existing_stage = job.progress.get_stage(stage)
progress_percentage = max(existing_stage.progress, progress_percentage)
# Explicitly preserve FAILURE: once a stage is marked FAILURE it should
# never regress to a non-failure state, regardless of enum ordering.
if existing_stage.status == JobState.FAILURE:
complete_state = JobState.FAILURE
except (ValueError, AttributeError):
pass # Stage doesn't exist yet; proceed normally

# Determine the status to write:
# - Stage complete (100%): use complete_state (SUCCESS or FAILURE)
# - Stage incomplete but FAILURE already determined: keep FAILURE visible
# - Stage incomplete, no failure: mark as in-progress (STARTED)
if progress_percentage >= 1.0:
status = complete_state
elif complete_state == JobState.FAILURE:
status = JobState.FAILURE
else:
status = JobState.STARTED

job.progress.update_stage(
stage,
status=complete_state if progress_percentage >= 1.0 else JobState.STARTED,
status=status,
progress=progress_percentage,
**state_params,
)
Expand All @@ -272,10 +313,10 @@ def _update_job_progress(
# Clean up async resources for completed jobs that use NATS/Redis
if job.progress.is_complete():
job = Job.objects.get(pk=job_id) # Re-fetch outside transaction
_cleanup_job_if_needed(job)
cleanup_async_job_if_needed(job)


def _cleanup_job_if_needed(job) -> None:
def cleanup_async_job_if_needed(job) -> None:
"""
Clean up async resources (NATS/Redis) if this job uses them.

Expand All @@ -291,7 +332,7 @@ def _cleanup_job_if_needed(job) -> None:
# import here to avoid circular imports
from ami.ml.orchestration.jobs import cleanup_async_job_resources

cleanup_async_job_resources(job)
cleanup_async_job_resources(job.pk, job.logger)


@task_prerun.connect(sender=run_job)
Expand Down Expand Up @@ -330,7 +371,7 @@ def update_job_status(sender, task_id, task, state: str, retval=None, **kwargs):

# Clean up async resources for revoked jobs
if state == JobState.REVOKED:
_cleanup_job_if_needed(job)
cleanup_async_job_if_needed(job)


@task_failure.connect(sender=run_job, retry=False)
Expand All @@ -345,7 +386,7 @@ def update_job_failure(sender, task_id, exception, *args, **kwargs):
job.save()

# Clean up async resources for failed jobs
_cleanup_job_if_needed(job)
cleanup_async_job_if_needed(job)


def log_time(start: float = 0, msg: str | None = None) -> tuple[float, Callable]:
Expand Down
Loading