Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions ami/jobs/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from asgiref.sync import async_to_sync, sync_to_async
from celery.signals import task_failure, task_postrun, task_prerun
from django.db import transaction
from django.db.models import Q
from redis.exceptions import RedisError

from ami.main.checks.schemas import IntegrityCheckResult
Expand All @@ -34,6 +35,60 @@
# "nobody's listening" signal.
WORKER_AVAILABILITY_ONLINE_CUTOFF = datetime.timedelta(minutes=5)

# Minimum interval between heartbeat DB writes for a given (pipeline, project).
# PROCESSING_SERVICE_LAST_SEEN_MAX is 60s; writing at most once per 30s keeps
# shared last_seen rows current without hammering them on every concurrent
# request for the same pipeline within a project.
HEARTBEAT_THROTTLE_SECONDS = 30
# Queued heartbeats older than two throttle windows are dropped by Celery:
# by then a fresher heartbeat has superseded them, so running them is wasted work.
HEARTBEAT_TASK_EXPIRES_SECONDS = HEARTBEAT_THROTTLE_SECONDS * 2


@celery_app.task(
    soft_time_limit=10,
    time_limit=15,
    expires=HEARTBEAT_TASK_EXPIRES_SECONDS,
    ignore_result=True,
    # No retries — a missed heartbeat is benign; retrying adds load for no gain.
)
def update_pipeline_pull_services_seen(job_id: int, seen_at_iso: str | None = None) -> None:
    """
    Fire-and-forget heartbeat task: record last_seen/last_seen_live for async
    (pull-mode) processing services linked to a job's pipeline.

    Called via .delay() from the tasks and result view endpoints so the HTTP
    request is never blocked on this DB write.

    Throttle: the UPDATE only touches services whose last_seen is missing or
    older than HEARTBEAT_THROTTLE_SECONDS before `seen_at`, cutting write rate
    under concurrent requests while keeping last_seen fresh relative to the
    60s PROCESSING_SERVICE_LAST_SEEN_MAX threshold.

    Scope: marks ALL async services on the pipeline within this project as live,
    not just the specific service that made the request. Once application-token
    auth is available (PR #1117), this should be scoped to the individual
    calling service instead.

    :param job_id: primary key of the Job whose pipeline's services to touch.
        Silently no-ops if the job no longer exists or has no pipeline.
    :param seen_at_iso: ISO-8601 timestamp of the originating request; falls
        back to "now" in the worker if omitted.
    """
    from ami.jobs.models import Job  # avoid circular import

    try:
        job = Job.objects.select_related("pipeline").get(pk=job_id)
    except Job.DoesNotExist:
        # Job was deleted between enqueue and execution — nothing to record.
        return

    if not job.pipeline_id:
        return

    # NOTE(review): naive datetime here and in the enqueuing view — confirm the
    # project runs with USE_TZ=False; otherwise use django.utils.timezone.now().
    seen_at = datetime.datetime.fromisoformat(seen_at_iso) if seen_at_iso is not None else datetime.datetime.now()
    throttle_cutoff = seen_at - datetime.timedelta(seconds=HEARTBEAT_THROTTLE_SECONDS)

    services_qs = job.pipeline.processing_services.async_services().filter(projects=job.project_id)
    stale_services_qs = services_qs.filter(Q(last_seen__isnull=True) | Q(last_seen__lt=throttle_cutoff))
    # A filtered UPDATE matching zero rows is already a no-op, so a separate
    # exists() pre-check would only add a second query per heartbeat — issue
    # the UPDATE directly. Rows newer than the cutoff are excluded by the
    # filter, so a delayed heartbeat can never regress a fresher last_seen.
    stale_services_qs.update(last_seen=seen_at, last_seen_live=True)


@celery_app.task(bind=True, soft_time_limit=default_soft_time_limit, time_limit=default_time_limit)
def run_job(self, job_id: int) -> None:
Expand Down
212 changes: 212 additions & 0 deletions ami/jobs/tests/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ami.jobs.models import Job, JobDispatchMode, JobProgress, JobState, MLJob, SourceImageCollectionPopulateJob
from ami.main.models import Project, SourceImage, SourceImageCollection
from ami.ml.models import Pipeline
from ami.ml.models.processing_service import ProcessingService
from ami.ml.orchestration.jobs import queue_images_to_nats
from ami.users.models import User

Expand Down Expand Up @@ -1016,3 +1017,214 @@ def test_tasks_endpoint_rejects_non_async_jobs(self):
resp = self.client.post(tasks_url, {"batch_size": 1}, format="json")
self.assertEqual(resp.status_code, 400)
self.assertIn("async_api", resp.json()[0].lower())


class TestPipelineHeartbeatTask(APITestCase):
    """
    Unit tests for update_pipeline_pull_services_seen and the view-level
    _mark_pipeline_pull_services_seen fire-and-forget dispatch.
    """

    def setUp(self):
        # One async (pull-mode) service sharing a pipeline + project with the job,
        # so the heartbeat task's (pipeline, project) scope matches exactly one row.
        self.project = Project.objects.create(name="Heartbeat Test Project")
        self.pipeline = Pipeline.objects.create(name="Heartbeat Pipeline", slug="heartbeat-pipeline")
        self.pipeline.projects.add(self.project)
        self.collection = SourceImageCollection.objects.create(name="HB Collection", project=self.project)
        self.job = Job.objects.create(
            job_type_key=MLJob.key,
            project=self.project,
            name="Heartbeat Test Job",
            pipeline=self.pipeline,
            source_image_collection=self.collection,
            dispatch_mode=JobDispatchMode.ASYNC_API,
        )
        self.service = ProcessingService.objects.create(
            name="Heartbeat Worker",
            endpoint_url=None,  # None = pull-mode / async service
        )
        self.service.pipelines.add(self.pipeline)
        self.service.projects.add(self.project)

    def test_tasks_endpoint_dispatches_heartbeat_task(self):
        """The /tasks endpoint calls update_pipeline_pull_services_seen.delay(), not the DB directly."""
        from unittest.mock import ANY, patch

        job = self.job
        job.status = JobState.STARTED
        job.save(update_fields=["status"])

        images = [
            SourceImage.objects.create(
                path=f"hb_tasks_{i}.jpg",
                public_base_url="http://example.com",
                project=self.project,
            )
            for i in range(2)
        ]
        queue_images_to_nats(job, images)

        user = User.objects.create_user(email="hbtest@example.com", is_superuser=True, is_active=True)
        self.client.force_authenticate(user=user)

        with patch("ami.jobs.views.update_pipeline_pull_services_seen.delay") as mock_delay:
            tasks_url = reverse_with_params("api:job-tasks", args=[job.pk], params={"project_id": self.project.pk})
            resp = self.client.post(tasks_url, {"batch_size": 1}, format="json")

        self.assertEqual(resp.status_code, 200)
        mock_delay.assert_called_once_with(job.pk, seen_at_iso=ANY)

    def test_result_endpoint_dispatches_heartbeat_task(self):
        """The /result endpoint calls update_pipeline_pull_services_seen.delay(), not the DB directly."""
        from unittest.mock import ANY, MagicMock, patch

        user = User.objects.create_user(email="hbresult@example.com", is_superuser=True, is_active=True)
        self.client.force_authenticate(user=user)

        result_data = {
            "results": [
                {
                    "reply_subject": "test.reply.hb",
                    "result": {
                        "pipeline": "heartbeat-pipeline",
                        "algorithms": {},
                        "total_time": 0.1,
                        "source_images": [],
                        "detections": [],
                        "errors": None,
                    },
                }
            ]
        }

        mock_async_result = MagicMock()
        mock_async_result.id = "hb-task-id"
        with (
            patch("ami.jobs.views.process_nats_pipeline_result.delay", return_value=mock_async_result),
            patch("ami.jobs.views.update_pipeline_pull_services_seen.delay") as mock_delay,
        ):
            result_url = reverse_with_params(
                "api:job-result", args=[self.job.pk], params={"project_id": self.project.pk}
            )
            resp = self.client.post(result_url, result_data, format="json")

        self.assertEqual(resp.status_code, 200)
        mock_delay.assert_called_once_with(self.job.pk, seen_at_iso=ANY)

    def test_tasks_endpoint_tolerates_heartbeat_dispatch_failure(self):
        """Heartbeat enqueue errors should not fail the /tasks response."""
        from unittest.mock import patch

        from kombu.exceptions import OperationalError

        job = self.job
        job.status = JobState.STARTED
        job.save(update_fields=["status"])

        image = SourceImage.objects.create(
            path="hb_tasks_broker.jpg",
            public_base_url="http://example.com",
            project=self.project,
        )
        queue_images_to_nats(job, [image])

        user = User.objects.create_user(email="hbbroker@example.com", is_superuser=True, is_active=True)
        self.client.force_authenticate(user=user)

        with patch(
            "ami.jobs.views.update_pipeline_pull_services_seen.delay",
            side_effect=OperationalError("broker unavailable"),
        ):
            tasks_url = reverse_with_params("api:job-tasks", args=[job.pk], params={"project_id": self.project.pk})
            resp = self.client.post(tasks_url, {"batch_size": 1}, format="json")

        self.assertEqual(resp.status_code, 200)
        self.assertEqual(len(resp.json()["tasks"]), 1)

    def test_heartbeat_task_updates_last_seen_when_stale(self):
        """update_pipeline_pull_services_seen writes last_seen when the service is stale."""
        import datetime

        from ami.jobs.tasks import update_pipeline_pull_services_seen

        # Set last_seen to well past the throttle window
        old_time = datetime.datetime.now() - datetime.timedelta(minutes=5)
        self.service.last_seen = old_time
        self.service.last_seen_live = False
        self.service.save(update_fields=["last_seen", "last_seen_live"])

        seen_at = datetime.datetime.now()
        update_pipeline_pull_services_seen(self.job.pk, seen_at_iso=seen_at.isoformat())

        self.service.refresh_from_db()
        self.assertTrue(self.service.last_seen_live)
        self.assertEqual(self.service.last_seen, seen_at)

    def test_heartbeat_task_skips_update_when_recent(self):
        """update_pipeline_pull_services_seen skips the UPDATE when last_seen is within the throttle window."""
        import datetime

        from ami.jobs.tasks import update_pipeline_pull_services_seen

        # Set last_seen to just now — well inside the 30s throttle window
        recent_time = datetime.datetime.now() - datetime.timedelta(seconds=5)
        self.service.last_seen = recent_time
        self.service.last_seen_live = True
        self.service.save(update_fields=["last_seen", "last_seen_live"])

        update_pipeline_pull_services_seen(self.job.pk, seen_at_iso=datetime.datetime.now().isoformat())

        self.service.refresh_from_db()
        # last_seen should not have advanced significantly (throttle skipped the UPDATE)
        self.assertAlmostEqual(
            self.service.last_seen.timestamp(),
            recent_time.timestamp(),
            delta=1.0,
        )

    def test_heartbeat_task_no_op_for_missing_job(self):
        """update_pipeline_pull_services_seen silently returns when job_id does not exist."""
        from ami.jobs.tasks import update_pipeline_pull_services_seen

        # Should not raise
        update_pipeline_pull_services_seen(job_id=999999)

    def test_heartbeat_task_no_op_for_job_without_pipeline(self):
        """update_pipeline_pull_services_seen returns early when job has no pipeline."""
        import datetime

        from ami.jobs.tasks import update_pipeline_pull_services_seen

        job_no_pipeline = Job.objects.create(
            job_type_key=MLJob.key,
            project=self.project,
            name="No-pipeline job",
            source_image_collection=self.collection,
            dispatch_mode=JobDispatchMode.ASYNC_API,
        )

        old_time = datetime.datetime.now() - datetime.timedelta(minutes=10)
        self.service.last_seen = old_time
        self.service.save(update_fields=["last_seen"])

        update_pipeline_pull_services_seen(job_no_pipeline.pk, seen_at_iso=datetime.datetime.now().isoformat())

        # Service last_seen should be unchanged because the task returned early
        self.service.refresh_from_db()
        self.assertAlmostEqual(self.service.last_seen.timestamp(), old_time.timestamp(), delta=1.0)

    def test_heartbeat_task_does_not_regress_newer_last_seen(self):
        """Delayed heartbeats must not overwrite a newer last_seen value."""
        import datetime

        from ami.jobs.tasks import update_pipeline_pull_services_seen

        newer_time = datetime.datetime.now()
        delayed_seen_at = newer_time - datetime.timedelta(minutes=1)
        self.service.last_seen = newer_time
        self.service.last_seen_live = True
        self.service.save(update_fields=["last_seen", "last_seen_live"])

        update_pipeline_pull_services_seen(self.job.pk, seen_at_iso=delayed_seen_at.isoformat())

        self.service.refresh_from_db()
        self.assertEqual(self.service.last_seen, newer_time)
39 changes: 21 additions & 18 deletions ami/jobs/views.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import datetime
import logging

import kombu.exceptions
Expand All @@ -24,7 +25,7 @@
MLJobTasksRequestSerializer,
MLJobTasksResponseSerializer,
)
from ami.jobs.tasks import process_nats_pipeline_result
from ami.jobs.tasks import process_nats_pipeline_result, update_pipeline_pull_services_seen
from ami.main.api.schemas import project_id_doc_param
from ami.main.api.views import DefaultViewSet
from ami.utils.fields import url_boolean_param
Expand Down Expand Up @@ -52,26 +53,28 @@ def _actor_log_context(request) -> tuple[str, str | None]:

def _mark_pipeline_pull_services_seen(job: "Job") -> None:
    """
    Enqueue a fire-and-forget heartbeat for async (pull-mode) processing services
    linked to the job's pipeline.

    Dispatches update_pipeline_pull_services_seen via Celery .delay() so the view
    is never blocked on the DB write. The task throttles writes to at most once per
    ~30 seconds per pipeline within this project, keeping last_seen current
    relative to the 60s PROCESSING_SERVICE_LAST_SEEN_MAX threshold without
    hammering the same rows on every concurrent task-fetch or result-submit
    request.

    Per-service scoping is not yet possible — marks ALL async services on the
    pipeline within this project as live. Once application-token auth lands
    (PR #1117) this can be scoped to the individual calling service.
    """
    if not job.pipeline_id:
        return
    try:
        update_pipeline_pull_services_seen.delay(job.pk, seen_at_iso=datetime.datetime.now().isoformat())
    except (kombu.exceptions.KombuError, ConnectionError, OSError) as exc:
        # Heartbeats are best-effort: a broker outage must never fail the
        # HTTP request that triggered the heartbeat. Log and move on.
        msg = f"Failed to enqueue non-critical pipeline heartbeat for job {job.pk}: {exc}"
        logger.warning(msg)
        job.logger.warning(msg)


class JobFilterSet(filters.FilterSet):
Expand Down
Loading