diff --git a/ami/jobs/management/commands/update_stale_jobs.py b/ami/jobs/management/commands/update_stale_jobs.py index da3a53e3d..0f19933dd 100644 --- a/ami/jobs/management/commands/update_stale_jobs.py +++ b/ami/jobs/management/commands/update_stale_jobs.py @@ -1,38 +1,37 @@ -from celery import states -from celery.result import AsyncResult from django.core.management.base import BaseCommand -from django.utils import timezone -from ami.jobs.models import Job, JobState +from ami.jobs.models import Job +from ami.jobs.tasks import check_stale_jobs class Command(BaseCommand): - help = ( - "Update the status of all jobs that are not in a final state " "and have not been updated in the last X hours." - ) + help = "Revoke stale jobs that have not been updated within the cutoff period." - # Add argument for the number of hours to consider a job stale def add_arguments(self, parser): parser.add_argument( "--hours", type=int, default=Job.FAILED_CUTOFF_HOURS, - help="Number of hours to consider a job stale", + help="Number of hours to consider a job stale (default: %(default)s)", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Show what would be done without making changes", ) def handle(self, *args, **options): - stale_jobs = Job.objects.filter( - status__in=JobState.running_states(), - updated_at__lt=timezone.now() - timezone.timedelta(hours=options["hours"]), - ) + results = check_stale_jobs(hours=options["hours"], dry_run=options["dry_run"]) + + if not results: + self.stdout.write("No stale jobs found.") + return - for job in stale_jobs: - task = AsyncResult(job.task_id) if job.task_id else None - if task: - job.update_status(task.state, save=False) - job.save() - self.stdout.write(self.style.SUCCESS(f"Updated status of job {job.pk} to {task.state}")) + prefix = "[dry-run] " if options["dry_run"] else "" + for r in results: + if r["action"] == "updated": + self.stdout.write( + self.style.SUCCESS(f"{prefix}Job {r['job_id']}: updated to {r['state']} (from Celery)") + ) else: - self.stdout.write(self.style.WARNING(f"Job {job.pk} has no associated task, setting status to FAILED")) - job.update_status(states.FAILURE, save=False) - job.save() + self.stdout.write(self.style.WARNING(f"{prefix}Job {r['job_id']}: revoked (no known Celery state)")) diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 917608be0..d8ff89d39 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -316,6 +316,97 @@ def _update_job_progress( cleanup_async_job_if_needed(job) +def check_stale_jobs(hours: int | None = None, dry_run: bool = False) -> list[dict]: + """ + Find jobs stuck in a running state past the cutoff and revoke them. + + For each stale job, checks Celery for a terminal task status. REVOKED is + always trusted. For async_api jobs, SUCCESS and FAILURE are only accepted + when job.progress.is_complete() — NATS workers may still be delivering + results after the Celery task finishes. All other cases result in revocation. + Async resources (NATS/Redis) are cleaned up in both branches. + + Returns a list of dicts describing what was done to each job. + """ + import datetime + + from celery import states + from celery.result import AsyncResult + from django.db import transaction + + from ami.jobs.models import Job, JobDispatchMode, JobState + + if hours is None: + hours = Job.FAILED_CUTOFF_HOURS + + cutoff = datetime.datetime.now() - datetime.timedelta(hours=hours) + stale_pks = list( + Job.objects.filter( + status__in=JobState.running_states(), + updated_at__lt=cutoff, + ).values_list("pk", flat=True) + ) + + results = [] + for pk in stale_pks: + with transaction.atomic(): + try: + job = Job.objects.select_for_update().get( + pk=pk, + status__in=JobState.running_states(), + updated_at__lt=cutoff, + ) + except Job.DoesNotExist: + # Another concurrent run already handled this job. + continue + + celery_state = None + if job.task_id: + try: + celery_state = AsyncResult(job.task_id).state + except Exception: + logger.warning( + "Failed to fetch Celery state for stale job %s (task_id=%s)", + job.pk, + job.task_id, + exc_info=True, + ) + # Treat as unknown state — job will be revoked below. + + # Only trust terminal Celery states. For async_api jobs, SUCCESS and + # FAILURE are only accepted when progress is complete — NATS workers may + # still be delivering results after the Celery task finishes. + is_terminal = celery_state in states.READY_STATES + is_async_api = job.dispatch_mode == JobDispatchMode.ASYNC_API + if is_async_api and celery_state in {states.SUCCESS, states.FAILURE} and not job.progress.is_complete(): + is_terminal = False + + previous_status = job.status + if is_terminal: + if not dry_run: + job.update_status(celery_state, save=False) + job.finished_at = datetime.datetime.now() + job.save() + else: + if not dry_run: + job.update_status(JobState.REVOKED, save=False) + job.finished_at = datetime.datetime.now() + job.save() + + # Async resource cleanup runs outside the transaction — it makes network + # calls (NATS/Redis) that should not hold the DB row lock. + if not dry_run: + job.refresh_from_db() + cleanup_async_job_if_needed(job) + + if is_terminal: + results.append({"job_id": job.pk, "action": "updated", "state": celery_state}) + else: + results.append({"job_id": job.pk, "action": "revoked", "previous_status": previous_status}) + + return results + + def cleanup_async_job_if_needed(job) -> None: """ Clean up async resources (NATS/Redis) if this job uses them. diff --git a/ami/jobs/tests/__init__.py b/ami/jobs/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/ami/jobs/tests.py b/ami/jobs/tests/test_jobs.py similarity index 100% rename from ami/jobs/tests.py rename to ami/jobs/tests/test_jobs.py diff --git a/ami/jobs/test_tasks.py b/ami/jobs/tests/test_tasks.py similarity index 100% rename from ami/jobs/test_tasks.py rename to ami/jobs/tests/test_tasks.py diff --git a/ami/jobs/tests/test_update_stale_jobs.py b/ami/jobs/tests/test_update_stale_jobs.py new file mode 100644 index 000000000..4a1e44427 --- /dev/null +++ b/ami/jobs/tests/test_update_stale_jobs.py @@ -0,0 +1,123 @@ +from datetime import timedelta +from unittest.mock import patch + +from django.test import TestCase +from django.utils import timezone + +from ami.jobs.models import Job, JobDispatchMode, JobState +from ami.jobs.tasks import check_stale_jobs +from ami.main.models import Project + + +class CheckStaleJobsTest(TestCase): + def setUp(self): + self.project = Project.objects.create(name="Stale jobs test project") + + def _create_job(self, status=JobState.STARTED, hours_ago=100, task_id=None): + job = Job.objects.create( + project=self.project, + name=f"Test job {status}", + status=status, + ) + Job.objects.filter(pk=job.pk).update( + updated_at=timezone.now() - timedelta(hours=hours_ago), + ) + if task_id is not None: + Job.objects.filter(pk=job.pk).update(task_id=task_id) + job.refresh_from_db() + return job + + @patch("ami.jobs.tasks.cleanup_async_job_if_needed") + def test_dry_run(self, mock_cleanup): + """Dry run returns results without modifying jobs.""" + job = self._create_job(status=JobState.STARTED) + + results = check_stale_jobs(dry_run=True) + + self.assertEqual(len(results), 1) + self.assertEqual(results[0]["action"], "revoked") + job.refresh_from_db() + self.assertEqual(job.status, JobState.STARTED.value) + mock_cleanup.assert_not_called() + + @patch("ami.jobs.tasks.cleanup_async_job_if_needed") + def test_revokes_stale_job(self, mock_cleanup): + """Stale job without a known Celery state is revoked and cleaned up.""" + job = self._create_job(status=JobState.STARTED) + + results = check_stale_jobs() + + self.assertEqual(len(results), 1) + result = results[0] + self.assertEqual(result["action"], "revoked") + self.assertEqual(result["previous_status"], JobState.STARTED) + job.refresh_from_db() + self.assertEqual(job.status, JobState.REVOKED.value) + self.assertIsNotNone(job.finished_at) + mock_cleanup.assert_called_once_with(job) + + @patch("ami.jobs.tasks.cleanup_async_job_if_needed") + @patch("celery.result.AsyncResult") + def test_updates_status_from_known_celery_state(self, mock_async_result, mock_cleanup): + """Stale job with a terminal Celery state is updated (not revoked).""" + from celery import states + + mock_async_result.return_value.state = states.FAILURE + job = self._create_job(status=JobState.STARTED, task_id="some-celery-task-id") + + results = check_stale_jobs() + + self.assertEqual(len(results), 1) + result = results[0] + self.assertEqual(result["action"], "updated") + self.assertEqual(result["state"], states.FAILURE) + job.refresh_from_db() + self.assertEqual(job.status, JobState.FAILURE.value) + self.assertIsNotNone(job.finished_at) + mock_cleanup.assert_called_once_with(job) + + @patch("ami.jobs.tasks.cleanup_async_job_if_needed") + @patch("celery.result.AsyncResult") + def test_revokes_success_with_incomplete_progress(self, mock_async_result, mock_cleanup): + """async_api job where Celery reports SUCCESS but progress is incomplete is revoked.""" + from celery import states + + mock_async_result.return_value.state = states.SUCCESS + job = self._create_job(status=JobState.STARTED, task_id="some-celery-task-id") + Job.objects.filter(pk=job.pk).update(dispatch_mode=JobDispatchMode.ASYNC_API) + job.refresh_from_db() + # job.progress.is_complete() returns False by default (no stages completed) + + results = check_stale_jobs() + + self.assertEqual(len(results), 1) + self.assertEqual(results[0]["action"], "revoked") + job.refresh_from_db() + self.assertEqual(job.status, JobState.REVOKED.value) + mock_cleanup.assert_called_once_with(job) + + @patch("ami.jobs.tasks.cleanup_async_job_if_needed") + @patch("celery.result.AsyncResult") + def test_revokes_when_celery_lookup_fails(self, mock_async_result, mock_cleanup): + """Job is revoked if Celery state lookup raises an exception.""" + mock_async_result.side_effect = ConnectionError("broker down") + job = self._create_job(status=JobState.STARTED, task_id="unreachable-task") + + results = check_stale_jobs() + + self.assertEqual(len(results), 1) + self.assertEqual(results[0]["action"], "revoked") + job.refresh_from_db() + self.assertEqual(job.status, JobState.REVOKED.value) + mock_cleanup.assert_called_once_with(job) + + @patch("ami.jobs.tasks.cleanup_async_job_if_needed") + def test_skips_recent_and_final_state_jobs(self, mock_cleanup): + """Recent jobs and jobs in final states are not touched.""" + self._create_job(status=JobState.STARTED, hours_ago=1) # recent + self._create_job(status=JobState.SUCCESS, hours_ago=200) # final state + + results = check_stale_jobs() + + self.assertEqual(results, []) + mock_cleanup.assert_not_called()