Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 21 additions & 22 deletions ami/jobs/management/commands/update_stale_jobs.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,37 @@
from celery import states
from celery.result import AsyncResult
from django.core.management.base import BaseCommand
from django.utils import timezone

from ami.jobs.models import Job, JobState
from ami.jobs.models import Job
from ami.jobs.tasks import check_stale_jobs


class Command(BaseCommand):
    """Management command that revokes jobs stuck in a running state.

    All of the work is delegated to ``check_stale_jobs``; this command only
    parses the options and prints one line per job that was (or, with
    ``--dry-run``, would be) touched.
    """

    help = "Revoke stale jobs that have not been updated within the cutoff period."

    def add_arguments(self, parser):
        """Register the ``--hours`` and ``--dry-run`` command-line options."""
        parser.add_argument(
            "--hours",
            type=int,
            default=Job.FAILED_CUTOFF_HOURS,
            help="Number of hours to consider a job stale (default: %(default)s)",
        )
        parser.add_argument(
            "--dry-run",
            action="store_true",
            help="Show what would be done without making changes",
        )

    def handle(self, *args, **options):
        """Run the stale-job check and report the outcome for each job."""
        results = check_stale_jobs(hours=options["hours"], dry_run=options["dry_run"])

        if not results:
            self.stdout.write("No stale jobs found.")
            return

        # Mark every line so dry-run output cannot be mistaken for real changes.
        prefix = "[dry-run] " if options["dry_run"] else ""
        for r in results:
            if r["action"] == "updated":
                self.stdout.write(
                    self.style.SUCCESS(f"{prefix}Job {r['job_id']}: updated to {r['state']} (from Celery)")
                )
            else:
                self.stdout.write(self.style.WARNING(f"{prefix}Job {r['job_id']}: revoked (no known Celery state)"))
50 changes: 50 additions & 0 deletions ami/jobs/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,56 @@ def _update_job_progress(
cleanup_async_job_if_needed(job)


def check_stale_jobs(hours: int | None = None, dry_run: bool = False) -> list[dict]:
    """
    Find jobs stuck in a running state past the cutoff and revoke them.

    For each stale job, checks Celery for a real task status. If Celery has one
    (e.g. SUCCESS, FAILURE), uses that. Otherwise revokes the job and cleans up
    any async resources (NATS/Redis).

    Args:
        hours: Age threshold in hours; a job whose ``updated_at`` is older than
            this is considered stale. Defaults to ``Job.FAILED_CUTOFF_HOURS``.
        dry_run: When true, only report what would happen; no job is saved and
            no cleanup is performed.

    Returns:
        One dict per stale job. Every dict has ``job_id``, ``action``
        (``"updated"`` or ``"revoked"``) and ``previous_status`` (the status the
        job had before this function touched it). ``"updated"`` entries also
        carry ``state``, the Celery state that was applied.
    """
    import datetime

    from celery import states
    from celery.result import AsyncResult
    from django.utils import timezone

    from ami.jobs.models import Job, JobState

    if hours is None:
        hours = Job.FAILED_CUTOFF_HOURS

    # Celery reports PENDING for task ids it knows nothing about, so PENDING is
    # not evidence that the task still exists.
    known_celery_states = frozenset(states.ALL_STATES) - {states.PENDING}

    # Use Django's clock so the cutoff matches how ``updated_at`` is stored
    # (timezone-aware when USE_TZ is enabled) instead of a naive local time.
    cutoff = timezone.now() - datetime.timedelta(hours=hours)
    stale_jobs = Job.objects.filter(
        status__in=JobState.running_states(),
        updated_at__lt=cutoff,
    )

    results = []
    for job in stale_jobs:
        # Capture the status before any mutation so the report reflects what
        # the job looked like when it was found, in both real and dry runs.
        previous_status = job.status

        celery_state = None
        if job.task_id:
            celery_state = AsyncResult(job.task_id).state

        if celery_state in known_celery_states:
            if not dry_run:
                job.update_status(celery_state, save=False)
                job.save()
            results.append(
                {
                    "job_id": job.pk,
                    "action": "updated",
                    "state": celery_state,
                    "previous_status": previous_status,
                }
            )
        else:
            if not dry_run:
                job.update_status(JobState.REVOKED, save=False)
                job.finished_at = timezone.now()
                job.save()
                cleanup_async_job_if_needed(job)
            results.append({"job_id": job.pk, "action": "revoked", "previous_status": previous_status})

    return results


def cleanup_async_job_if_needed(job) -> None:
"""
Clean up async resources (NATS/Redis) if this job uses them.
Expand Down
66 changes: 66 additions & 0 deletions ami/jobs/tests_update_stale_jobs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from datetime import timedelta
from unittest.mock import patch

from django.test import TestCase
from django.utils import timezone

from ami.jobs.models import Job, JobState
from ami.jobs.tasks import check_stale_jobs
from ami.main.models import Project


class CheckStaleJobsTest(TestCase):
    """Exercises ``check_stale_jobs`` against jobs of varying age and status."""

    # Patched in every test so no real async cleanup is attempted.
    CLEANUP_TARGET = "ami.jobs.tasks.cleanup_async_job_if_needed"

    def setUp(self):
        self.project = Project.objects.create(name="Stale jobs test project")

    def _create_job(self, status=JobState.STARTED, hours_ago=100, task_id=None):
        """Create a job and back-date its ``updated_at`` by ``hours_ago`` hours.

        The timestamp (and optional ``task_id``) are written with queryset
        updates on the saved row, then the instance is reloaded.
        """
        job = Job.objects.create(
            project=self.project,
            name=f"Test job {status}",
            status=status,
        )
        backdated = timezone.now() - timedelta(hours=hours_ago)
        Job.objects.filter(pk=job.pk).update(updated_at=backdated)
        if task_id is not None:
            Job.objects.filter(pk=job.pk).update(task_id=task_id)
        job.refresh_from_db()
        return job

    def test_dry_run(self):
        """A dry run reports the stale job but leaves it untouched."""
        with patch(self.CLEANUP_TARGET) as mock_cleanup:
            stale = self._create_job(status=JobState.STARTED)

            outcome = check_stale_jobs(dry_run=True)

            self.assertEqual(len(outcome), 1)
            self.assertEqual(outcome[0]["action"], "revoked")
            stale.refresh_from_db()
            self.assertEqual(stale.status, JobState.STARTED.value)
            mock_cleanup.assert_not_called()

    def test_revokes_stale_job(self):
        """A stale job with no known Celery state is revoked and cleaned up."""
        with patch(self.CLEANUP_TARGET) as mock_cleanup:
            stale = self._create_job(status=JobState.STARTED)

            outcome = check_stale_jobs()

            self.assertEqual(len(outcome), 1)
            self.assertEqual(outcome[0]["action"], "revoked")
            stale.refresh_from_db()
            self.assertEqual(stale.status, JobState.REVOKED.value)
            self.assertIsNotNone(stale.finished_at)
            mock_cleanup.assert_called_once_with(stale)

    def test_skips_recent_and_final_state_jobs(self):
        """Recently updated jobs and jobs already in a final state are ignored."""
        with patch(self.CLEANUP_TARGET) as mock_cleanup:
            self._create_job(status=JobState.STARTED, hours_ago=1)  # too recent to be stale
            self._create_job(status=JobState.SUCCESS, hours_ago=200)  # old, but already finished

            outcome = check_stale_jobs()

            self.assertEqual(outcome, [])
            mock_cleanup.assert_not_called()