Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions ami/jobs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,13 @@ class Job(BaseModel):
# N minutes". 10 is conservative; raise if legitimate long-running jobs get
# reaped.
STALLED_JOBS_MAX_MINUTES = 10
# Zombie-stream reaper: age threshold above which a NATS stream for a job
# in a terminal state (or missing from Django) is considered safe to drop.
# Kept well above :attr:`STALLED_JOBS_MAX_MINUTES` so newly-dispatched jobs
# whose stream was created before ``transaction.on_commit`` saved the Job
# row do not get reaped. Tighten only if ``cleanup-on-cancel`` misses are
# still stranding consumer poll cycles after this safety net lands.
ZOMBIE_STREAMS_MAX_AGE_MINUTES = STALLED_JOBS_MAX_MINUTES * 6

name = models.CharField(max_length=255)
queue = models.CharField(max_length=255, default="default")
Expand Down
89 changes: 88 additions & 1 deletion ami/jobs/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections.abc import Callable
from typing import TYPE_CHECKING

from asgiref.sync import async_to_sync
from asgiref.sync import async_to_sync, sync_to_async
from celery.signals import task_failure, task_postrun, task_prerun
from django.db import transaction
from redis.exceptions import RedisError
Expand Down Expand Up @@ -558,6 +558,7 @@ class JobsHealthCheckResult:

stale_jobs: IntegrityCheckResult
running_job_snapshots: IntegrityCheckResult
zombie_streams: IntegrityCheckResult


def _run_stale_jobs_check() -> IntegrityCheckResult:
Expand Down Expand Up @@ -635,6 +636,91 @@ async def _snapshot_all() -> None:
return IntegrityCheckResult(checked=len(running_jobs), fixed=0, unfixable=errors)


def _run_zombie_streams_check() -> IntegrityCheckResult:
"""Drain NATS streams that outlived their Django Job.

Defense-in-depth for the cleanup-on-cancel path: a stream whose Job is in
a terminal state (or was deleted) is consuming worker poll cycles for no
reason. The age guard (``Job.ZOMBIE_STREAMS_MAX_AGE_MINUTES``) prevents
races with freshly-dispatched jobs whose NATS stream is created before
``transaction.on_commit`` persists the Job row.

Observations-only for healthy in-flight jobs; only drains when both
conditions hold:

* Job is ``None`` or in :meth:`JobState.final_states`
* Stream's NATS-reported ``created`` timestamp is older than the threshold

``checked`` counts job-shaped streams inspected; ``fixed`` counts those
actually drained; ``unfixable`` counts per-stream drain failures.
"""
from ami.jobs.models import Job, JobState

threshold = datetime.timedelta(minutes=Job.ZOMBIE_STREAMS_MAX_AGE_MINUTES)
now = datetime.datetime.now()

async def _drain_all() -> tuple[int, int, int]:
async with TaskQueueManager() as manager:
snapshots = await manager.list_job_stream_snapshots()
if not snapshots:
return 0, 0, 0

job_ids = [s["job_id"] for s in snapshots]
jobs_by_id = await sync_to_async(
lambda ids: {j.pk: j for j in Job.objects.filter(pk__in=ids).only("pk", "status")}
)(job_ids)

checked = len(snapshots)
drained = 0
errored = 0
for snap in snapshots:
created = snap["created"]
age = now - created if created else threshold + datetime.timedelta(minutes=1)
Comment thread
mihow marked this conversation as resolved.
Outdated
if age < threshold:
continue
job = jobs_by_id.get(snap["job_id"])
job_status = job.status if job else None
if job is not None and JobState(job_status) not in JobState.final_states():
continue
status_label = str(job_status) if job else "missing"
try:
consumer_deleted = await manager.delete_consumer(snap["job_id"])
stream_deleted = await manager.delete_stream(snap["job_id"])
except Exception:
errored += 1
logger.exception("Failed draining zombie NATS stream for job %s", snap["job_id"])
continue
if stream_deleted:
drained += 1
age_hours = age.total_seconds() / 3600.0
logger.info(
"Drained zombie NATS stream %s (status=%s, age=%.1fh, redelivered=%s, consumer_deleted=%s)",
snap["stream_name"],
status_label,
age_hours,
snap["num_redelivered"],
consumer_deleted,
)
else:
errored += 1
return checked, drained, errored

try:
checked, drained, errored = async_to_sync(_drain_all)()
except Exception:
logger.exception("zombie_streams check: connection/setup failed")
return IntegrityCheckResult(checked=0, fixed=0, unfixable=1)

log_fn = logger.warning if errored else logger.info
log_fn(
"zombie_streams check: %d stream(s) inspected, %d drained, %d error(s)",
checked,
drained,
errored,
)
return IntegrityCheckResult(checked=checked, fixed=drained, unfixable=errored)


def _safe_run_sub_check(name: str, fn: Callable[[], IntegrityCheckResult]) -> IntegrityCheckResult:
"""Run one umbrella sub-check, returning an ``unfixable=1`` sentinel on failure.

Expand Down Expand Up @@ -664,6 +750,7 @@ def jobs_health_check() -> dict:
result = JobsHealthCheckResult(
stale_jobs=_safe_run_sub_check("stale_jobs", _run_stale_jobs_check),
running_job_snapshots=_safe_run_sub_check("running_job_snapshots", _run_running_job_snapshot_check),
zombie_streams=_safe_run_sub_check("zombie_streams", _run_zombie_streams_check),
)
return dataclasses.asdict(result)

Expand Down
126 changes: 126 additions & 0 deletions ami/jobs/tests/test_periodic_beat_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ def _stub_manager(self, mock_manager_cls) -> AsyncMock:
instance.__aenter__ = AsyncMock(return_value=instance)
instance.__aexit__ = AsyncMock(return_value=False)
instance.log_consumer_stats_snapshot = AsyncMock()
# Zombie-stream sub-check defaults: no streams to inspect, no drains.
instance.list_job_stream_snapshots = AsyncMock(return_value=[])
instance.delete_consumer = AsyncMock(return_value=True)
instance.delete_stream = AsyncMock(return_value=True)
return instance

def test_reports_both_sub_check_results(self, mock_manager_cls, _mock_cleanup):
Expand All @@ -50,6 +54,7 @@ def test_reports_both_sub_check_results(self, mock_manager_cls, _mock_cleanup):
{
"stale_jobs": {"checked": 2, "fixed": 2, "unfixable": 0},
"running_job_snapshots": _empty_check_dict(),
"zombie_streams": _empty_check_dict(),
},
)

Expand All @@ -63,6 +68,7 @@ def test_idle_deployment_returns_all_zeros(self, mock_manager_cls, _mock_cleanup
{
"stale_jobs": _empty_check_dict(),
"running_job_snapshots": _empty_check_dict(),
"zombie_streams": _empty_check_dict(),
},
)

Expand Down Expand Up @@ -173,3 +179,123 @@ def __init__(self, task_id):

# checked == 2 (both stale), fixed == 2 (one per branch), unfixable == 0
self.assertEqual(result["stale_jobs"], {"checked": 2, "fixed": 2, "unfixable": 0})

def test_zombie_stream_drained_when_job_is_terminal_and_old(self, mock_manager_cls, _mock_cleanup):
"""An old stream whose Job is in a final state should be drained."""
import datetime

terminal_job = Job.objects.create(project=self.project, name="zombie owner", status=JobState.SUCCESS)
instance = self._stub_manager(mock_manager_cls)
old_ts = datetime.datetime.now() - datetime.timedelta(minutes=Job.ZOMBIE_STREAMS_MAX_AGE_MINUTES + 5)
instance.list_job_stream_snapshots = AsyncMock(
return_value=[
{
"job_id": terminal_job.pk,
"stream_name": f"job_{terminal_job.pk}",
"created": old_ts,
"messages": 0,
"num_redelivered": 7,
}
]
)

result = jobs_health_check()

self.assertEqual(result["zombie_streams"], {"checked": 1, "fixed": 1, "unfixable": 0})
instance.delete_consumer.assert_awaited_once_with(terminal_job.pk)
instance.delete_stream.assert_awaited_once_with(terminal_job.pk)

def test_zombie_stream_drained_when_job_is_missing_and_old(self, mock_manager_cls, _mock_cleanup):
"""An old stream whose Job row no longer exists should be drained."""
import datetime

instance = self._stub_manager(mock_manager_cls)
old_ts = datetime.datetime.now() - datetime.timedelta(minutes=Job.ZOMBIE_STREAMS_MAX_AGE_MINUTES + 1)
instance.list_job_stream_snapshots = AsyncMock(
return_value=[
{
"job_id": 987654, # no Job row with this pk
"stream_name": "job_987654",
"created": old_ts,
"messages": 3,
"num_redelivered": 0,
}
]
)

result = jobs_health_check()

self.assertEqual(result["zombie_streams"], {"checked": 1, "fixed": 1, "unfixable": 0})
instance.delete_stream.assert_awaited_once_with(987654)

def test_zombie_stream_not_drained_when_below_age_threshold(self, mock_manager_cls, _mock_cleanup):
"""A fresh stream for a terminal job must NOT be drained (on_commit race guard)."""
import datetime

terminal_job = Job.objects.create(project=self.project, name="fresh zombie?", status=JobState.FAILURE)
instance = self._stub_manager(mock_manager_cls)
fresh_ts = datetime.datetime.now() - datetime.timedelta(minutes=1)
instance.list_job_stream_snapshots = AsyncMock(
return_value=[
{
"job_id": terminal_job.pk,
"stream_name": f"job_{terminal_job.pk}",
"created": fresh_ts,
"messages": 0,
"num_redelivered": 0,
}
]
)

result = jobs_health_check()

self.assertEqual(result["zombie_streams"], {"checked": 1, "fixed": 0, "unfixable": 0})
instance.delete_stream.assert_not_awaited()

def test_zombie_stream_not_drained_when_job_still_running(self, mock_manager_cls, _mock_cleanup):
"""An old stream for a still-running job must NOT be drained."""
import datetime

running_job = Job.objects.create(project=self.project, name="still running", status=JobState.STARTED)
instance = self._stub_manager(mock_manager_cls)
old_ts = datetime.datetime.now() - datetime.timedelta(minutes=Job.ZOMBIE_STREAMS_MAX_AGE_MINUTES + 10)
instance.list_job_stream_snapshots = AsyncMock(
return_value=[
{
"job_id": running_job.pk,
"stream_name": f"job_{running_job.pk}",
"created": old_ts,
"messages": 5,
"num_redelivered": 0,
}
]
)

result = jobs_health_check()

self.assertEqual(result["zombie_streams"], {"checked": 1, "fixed": 0, "unfixable": 0})
instance.delete_stream.assert_not_awaited()

def test_zombie_stream_drain_failure_counts_as_unfixable(self, mock_manager_cls, _mock_cleanup):
"""A drain that raises should be counted as unfixable without crashing the umbrella."""
import datetime

terminal_job = Job.objects.create(project=self.project, name="unfixable", status=JobState.SUCCESS)
instance = self._stub_manager(mock_manager_cls)
old_ts = datetime.datetime.now() - datetime.timedelta(minutes=Job.ZOMBIE_STREAMS_MAX_AGE_MINUTES + 2)
instance.list_job_stream_snapshots = AsyncMock(
return_value=[
{
"job_id": terminal_job.pk,
"stream_name": f"job_{terminal_job.pk}",
"created": old_ts,
"messages": 0,
"num_redelivered": 0,
}
]
)
instance.delete_consumer = AsyncMock(side_effect=RuntimeError("nats error"))

result = jobs_health_check()

self.assertEqual(result["zombie_streams"], {"checked": 1, "fixed": 0, "unfixable": 1})
Loading
Loading