Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 21 additions & 22 deletions ami/jobs/management/commands/update_stale_jobs.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,37 @@
from celery import states
from celery.result import AsyncResult
from django.core.management.base import BaseCommand
from django.utils import timezone

from ami.jobs.models import Job, JobState
from ami.jobs.models import Job
from ami.jobs.tasks import check_stale_jobs


class Command(BaseCommand):
    """
    Management command wrapper around ``check_stale_jobs``.

    Parses the ``--hours`` and ``--dry-run`` options, delegates the work to
    ``check_stale_jobs`` and prints one line per job it reports.
    """

    help = "Revoke stale jobs that have not been updated within the cutoff period."

    # Add argument for the number of hours to consider a job stale
    def add_arguments(self, parser):
        """Register the ``--hours`` and ``--dry-run`` command-line options."""
        parser.add_argument(
            "--hours",
            type=int,
            default=Job.FAILED_CUTOFF_HOURS,
            help="Number of hours to consider a job stale (default: %(default)s)",
        )
        parser.add_argument(
            "--dry-run",
            action="store_true",
            help="Show what would be done without making changes",
        )

    def handle(self, *args, **options):
        """
        Run the stale-job check and report the outcome of each job on stdout.

        Each result dict returned by ``check_stale_jobs`` is printed either as
        a SUCCESS line (action ``"updated"``) or as a WARNING line (any other
        action, i.e. the job was revoked). In dry-run mode every line is
        prefixed with ``[dry-run] ``.
        """
        results = check_stale_jobs(hours=options["hours"], dry_run=options["dry_run"])

        if not results:
            self.stdout.write("No stale jobs found.")
            return

        prefix = "[dry-run] " if options["dry_run"] else ""
        for r in results:
            if r["action"] == "updated":
                self.stdout.write(
                    self.style.SUCCESS(f"{prefix}Job {r['job_id']}: updated to {r['state']} (from Celery)")
                )
            else:
                self.stdout.write(self.style.WARNING(f"{prefix}Job {r['job_id']}: revoked (no known Celery state)"))
91 changes: 91 additions & 0 deletions ami/jobs/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,97 @@ def _update_job_progress(
cleanup_async_job_if_needed(job)


def check_stale_jobs(hours: int | None = None, dry_run: bool = False) -> list[dict]:
    """
    Find jobs stuck in a running state past the cutoff and revoke them.

    For each stale job, checks Celery for a terminal task status. REVOKED is
    always trusted. For async_api jobs, SUCCESS and FAILURE are only accepted
    when job.progress.is_complete() — NATS workers may still be delivering
    results after the Celery task finishes. All other cases result in revocation.
    Async resources (NATS/Redis) are cleaned up in both branches.

    Args:
        hours: Age threshold in hours; jobs whose ``updated_at`` is older than
            this are considered stale. Defaults to ``Job.FAILED_CUTOFF_HOURS``
            when ``None``.
        dry_run: When true, report what would happen without saving any job
            and without running async resource cleanup.

    Returns:
        A list of dicts describing what was done to each job. Every dict has
        ``job_id`` and ``action``; ``"updated"`` entries also carry ``state``
        (the Celery state applied) and ``"revoked"`` entries carry
        ``previous_status``.
    """
    import datetime

    from celery import states
    from celery.result import AsyncResult
    from django.db import transaction
    from django.utils import timezone

    from ami.jobs.models import Job, JobDispatchMode, JobState

    if hours is None:
        hours = Job.FAILED_CUTOFF_HOURS

    # timezone.now() follows the project's USE_TZ setting: it is a naive local
    # datetime when USE_TZ is off and an aware UTC datetime when it is on, so
    # the comparison against updated_at is valid in both configurations.
    cutoff = timezone.now() - datetime.timedelta(hours=hours)
    stale_pks = list(
        Job.objects.filter(
            status__in=JobState.running_states(),
            updated_at__lt=cutoff,
        ).values_list("pk", flat=True)
    )

    results = []
    for pk in stale_pks:
        with transaction.atomic():
            try:
                # Re-check the staleness criteria under a row lock so that two
                # concurrent runs cannot both process the same job.
                job = Job.objects.select_for_update().get(
                    pk=pk,
                    status__in=JobState.running_states(),
                    updated_at__lt=cutoff,
                )
            except Job.DoesNotExist:
                # Another concurrent run already handled this job.
                continue

            celery_state = None
            if job.task_id:
                try:
                    celery_state = AsyncResult(job.task_id).state
                except Exception:
                    logger.warning(
                        "Failed to fetch Celery state for stale job %s (task_id=%s)",
                        job.pk,
                        job.task_id,
                        exc_info=True,
                    )
                    # Treat as unknown state — job will be revoked below.

            # Only trust terminal Celery states. For async_api jobs, SUCCESS and
            # FAILURE are only accepted when progress is complete — NATS workers may
            # still be delivering results after the Celery task finishes.
            is_terminal = celery_state in states.READY_STATES
            is_async_api = job.dispatch_mode == JobDispatchMode.ASYNC_API
            if is_async_api and celery_state in {states.SUCCESS, states.FAILURE} and not job.progress.is_complete():
                is_terminal = False

            previous_status = job.status
            # A trusted terminal Celery state is mirrored onto the job; anything
            # else (unknown, non-terminal, lookup failure) means revocation.
            new_status = celery_state if is_terminal else JobState.REVOKED
            if not dry_run:
                job.update_status(new_status, save=False)
                job.finished_at = timezone.now()
                job.save()

        # Async resource cleanup runs outside the transaction — it makes network
        # calls (NATS/Redis) that should not hold the DB row lock.
        if not dry_run:
            job.refresh_from_db()
            cleanup_async_job_if_needed(job)

        if is_terminal:
            results.append({"job_id": job.pk, "action": "updated", "state": celery_state})
        else:
            results.append({"job_id": job.pk, "action": "revoked", "previous_status": previous_status})

    return results


def cleanup_async_job_if_needed(job) -> None:
"""
Clean up async resources (NATS/Redis) if this job uses them.
Expand Down
Empty file added ami/jobs/tests/__init__.py
Empty file.
File renamed without changes.
File renamed without changes.
123 changes: 123 additions & 0 deletions ami/jobs/tests/test_update_stale_jobs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from datetime import timedelta
from unittest.mock import patch

from django.test import TestCase
from django.utils import timezone

from ami.jobs.models import Job, JobDispatchMode, JobState
from ami.jobs.tasks import check_stale_jobs
from ami.main.models import Project


class CheckStaleJobsTest(TestCase):
    """
    Tests for ``check_stale_jobs``.

    ``cleanup_async_job_if_needed`` is patched in every test so that no async
    resource cleanup is actually attempted; tests that need a specific Celery
    state additionally patch ``celery.result.AsyncResult``.
    """

    def setUp(self):
        self.project = Project.objects.create(name="Stale jobs test project")

    def _create_job(self, status=JobState.STARTED, hours_ago=100, task_id=None):
        """
        Create a job in ``status`` whose ``updated_at`` lies ``hours_ago`` hours in the past.

        ``updated_at`` (and ``task_id``, when given) are written with a
        queryset ``update()`` after creation, then the instance is reloaded so
        the returned job reflects the stored values.
        """
        job = Job.objects.create(
            project=self.project,
            name=f"Test job {status}",
            status=status,
        )
        Job.objects.filter(pk=job.pk).update(
            updated_at=timezone.now() - timedelta(hours=hours_ago),
        )
        if task_id is not None:
            Job.objects.filter(pk=job.pk).update(task_id=task_id)
        job.refresh_from_db()
        return job

    @patch("ami.jobs.tasks.cleanup_async_job_if_needed")
    def test_dry_run(self, mock_cleanup):
        """Dry run returns results without modifying jobs."""
        job = self._create_job(status=JobState.STARTED)

        results = check_stale_jobs(dry_run=True)

        self.assertEqual(len(results), 1)
        self.assertEqual(results[0]["action"], "revoked")
        job.refresh_from_db()
        self.assertEqual(job.status, JobState.STARTED.value)
        mock_cleanup.assert_not_called()

    @patch("ami.jobs.tasks.cleanup_async_job_if_needed")
    def test_revokes_stale_job(self, mock_cleanup):
        """Stale job without a known Celery state is revoked and cleaned up."""
        job = self._create_job(status=JobState.STARTED)

        results = check_stale_jobs()

        self.assertEqual(len(results), 1)
        result = results[0]
        self.assertEqual(result["action"], "revoked")
        self.assertEqual(result["previous_status"], JobState.STARTED)
        job.refresh_from_db()
        self.assertEqual(job.status, JobState.REVOKED.value)
        self.assertIsNotNone(job.finished_at)
        mock_cleanup.assert_called_once_with(job)

    @patch("ami.jobs.tasks.cleanup_async_job_if_needed")
    @patch("celery.result.AsyncResult")
    def test_updates_status_from_known_celery_state(self, mock_async_result, mock_cleanup):
        """Stale job with a terminal Celery state is updated (not revoked)."""
        from celery import states

        mock_async_result.return_value.state = states.FAILURE
        job = self._create_job(status=JobState.STARTED, task_id="some-celery-task-id")

        results = check_stale_jobs()

        self.assertEqual(len(results), 1)
        result = results[0]
        self.assertEqual(result["action"], "updated")
        self.assertEqual(result["state"], states.FAILURE)
        job.refresh_from_db()
        self.assertEqual(job.status, JobState.FAILURE.value)
        self.assertIsNotNone(job.finished_at)
        mock_cleanup.assert_called_once_with(job)

    @patch("ami.jobs.tasks.cleanup_async_job_if_needed")
    @patch("celery.result.AsyncResult")
    def test_revokes_success_with_incomplete_progress(self, mock_async_result, mock_cleanup):
        """async_api job where Celery reports SUCCESS but progress is incomplete is revoked."""
        from celery import states

        mock_async_result.return_value.state = states.SUCCESS
        job = self._create_job(status=JobState.STARTED, task_id="some-celery-task-id")
        Job.objects.filter(pk=job.pk).update(dispatch_mode=JobDispatchMode.ASYNC_API)
        job.refresh_from_db()
        # job.progress.is_complete() returns False by default (no stages completed)

        results = check_stale_jobs()

        self.assertEqual(len(results), 1)
        self.assertEqual(results[0]["action"], "revoked")
        job.refresh_from_db()
        self.assertEqual(job.status, JobState.REVOKED.value)
        mock_cleanup.assert_called_once_with(job)

    @patch("ami.jobs.tasks.cleanup_async_job_if_needed")
    @patch("celery.result.AsyncResult")
    def test_revokes_when_celery_lookup_fails(self, mock_async_result, mock_cleanup):
        """Job is revoked if Celery state lookup raises an exception."""
        mock_async_result.side_effect = ConnectionError("broker down")
        job = self._create_job(status=JobState.STARTED, task_id="unreachable-task")

        results = check_stale_jobs()

        self.assertEqual(len(results), 1)
        self.assertEqual(results[0]["action"], "revoked")
        job.refresh_from_db()
        self.assertEqual(job.status, JobState.REVOKED.value)
        mock_cleanup.assert_called_once_with(job)

    @patch("ami.jobs.tasks.cleanup_async_job_if_needed")
    def test_skips_recent_and_final_state_jobs(self, mock_cleanup):
        """Recent jobs and jobs in final states are not touched."""
        self._create_job(status=JobState.STARTED, hours_ago=1)  # recent
        self._create_job(status=JobState.SUCCESS, hours_ago=200)  # final state

        results = check_stale_jobs()

        self.assertEqual(results, [])
        mock_cleanup.assert_not_called()