Merged
41 changes: 41 additions & 0 deletions ami/jobs/tasks.py
@@ -34,6 +34,47 @@
# "nobody's listening" signal.
WORKER_AVAILABILITY_ONLINE_CUTOFF = datetime.timedelta(minutes=5)

# Minimum interval between heartbeat dispatches for a given (pipeline, project).
# The view-level Redis cache gate uses this window to skip .delay() under
# concurrent polling; the task itself does no throttling.
HEARTBEAT_THROTTLE_SECONDS = 30


@celery_app.task(
soft_time_limit=10,
time_limit=15,
ignore_result=True,
# No retries — a missed heartbeat is benign; retrying adds load for no gain.
)
def update_pipeline_pull_services_seen(job_id: int) -> None:
"""
Fire-and-forget heartbeat task: record last_seen/last_seen_live for async
(pull-mode) processing services linked to a job's pipeline.

Throttling lives in the view (Redis cache gate over HEARTBEAT_THROTTLE_SECONDS),
so this task is dispatched at most once per (pipeline, project) per window
and can just write.

Scope: marks ALL async services on the pipeline within this project as live,
not just the specific service that made the request. Once application-token
auth is available (PR #1117), this should be scoped to the individual
calling service instead.
"""
from ami.jobs.models import Job # avoid circular import

try:
job = Job.objects.select_related("pipeline").get(pk=job_id)
except Job.DoesNotExist:
return

if not job.pipeline_id:
return

job.pipeline.processing_services.async_services().filter(projects=job.project_id).update(
last_seen=datetime.datetime.now(),
last_seen_live=True,
)


@celery_app.task(bind=True, soft_time_limit=default_soft_time_limit, time_limit=default_time_limit)
def run_job(self, job_id: int) -> None:
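The docstrings in this PR reference a periodic `check_processing_services_online` task that marks services offline once these heartbeats stop arriving within `PROCESSING_SERVICE_LAST_SEEN_MAX`. That task is not shown in this diff; the following is an editor's sketch of its plausible shape, assuming `async_services()` exists on the model's default manager (the diff only shows it on a related manager) and that the setting is a `datetime.timedelta`:

# Editor's sketch, not part of this PR. Assumed names: async_services() on the
# default manager, and PROCESSING_SERVICE_LAST_SEEN_MAX as a module-level
# timedelta setting referenced in the views.py docstring below.
@celery_app.task(ignore_result=True)
def check_processing_services_online() -> None:
    """Mark pull-mode services offline when their heartbeat has gone stale."""
    import datetime

    from ami.ml.models.processing_service import ProcessingService

    cutoff = datetime.datetime.now() - PROCESSING_SERVICE_LAST_SEEN_MAX
    ProcessingService.objects.async_services().filter(
        last_seen_live=True,
        last_seen__lt=cutoff,
    ).update(last_seen_live=False)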
141 changes: 141 additions & 0 deletions ami/jobs/tests/test_jobs.py
@@ -11,6 +11,7 @@
from ami.jobs.models import Job, JobDispatchMode, JobProgress, JobState, MLJob, SourceImageCollectionPopulateJob
from ami.main.models import Project, SourceImage, SourceImageCollection
from ami.ml.models import Pipeline
from ami.ml.models.processing_service import ProcessingService
from ami.ml.orchestration.jobs import queue_images_to_nats
from ami.users.models import User

@@ -1016,3 +1017,143 @@ def test_tasks_endpoint_rejects_non_async_jobs(self):
resp = self.client.post(tasks_url, {"batch_size": 1}, format="json")
self.assertEqual(resp.status_code, 400)
self.assertIn("async_api", resp.json()[0].lower())


class TestPipelineHeartbeatTask(APITestCase):
Review comment (collaborator, author): This is a lot of tests for a "nice to have" feature that is slowing down our required core features. What is necessary?

"""
Unit tests for update_pipeline_pull_services_seen and the view-level
_mark_pipeline_pull_services_seen fire-and-forget dispatch.
"""

def setUp(self):
from django.core.cache import cache

        # The cache-based gate in _mark_pipeline_pull_services_seen would otherwise
        # carry over between tests and suppress the .delay() calls we want to assert.
cache.clear()

self.project = Project.objects.create(name="Heartbeat Test Project")
self.pipeline = Pipeline.objects.create(name="Heartbeat Pipeline", slug="heartbeat-pipeline")
self.pipeline.projects.add(self.project)
self.collection = SourceImageCollection.objects.create(name="HB Collection", project=self.project)
self.job = Job.objects.create(
job_type_key=MLJob.key,
project=self.project,
name="Heartbeat Test Job",
pipeline=self.pipeline,
source_image_collection=self.collection,
dispatch_mode=JobDispatchMode.ASYNC_API,
)
self.service = ProcessingService.objects.create(
name="Heartbeat Worker",
endpoint_url=None, # None = pull-mode / async service
)
self.service.pipelines.add(self.pipeline)
self.service.projects.add(self.project)

def test_tasks_endpoint_dispatches_heartbeat_task(self):
"""The /tasks endpoint calls update_pipeline_pull_services_seen.delay(), not the DB directly."""
from unittest.mock import patch

job = self.job
job.status = JobState.STARTED
job.save(update_fields=["status"])

images = [
SourceImage.objects.create(
path=f"hb_tasks_{i}.jpg",
public_base_url="http://example.com",
project=self.project,
)
for i in range(2)
]
queue_images_to_nats(job, images)

user = User.objects.create_user(email="hbtest@example.com", is_superuser=True, is_active=True)
self.client.force_authenticate(user=user)

with patch("ami.jobs.views.update_pipeline_pull_services_seen.delay") as mock_delay:
tasks_url = reverse_with_params("api:job-tasks", args=[job.pk], params={"project_id": self.project.pk})
resp = self.client.post(tasks_url, {"batch_size": 1}, format="json")

self.assertEqual(resp.status_code, 200)
mock_delay.assert_called_once_with(job.pk)

def test_result_endpoint_dispatches_heartbeat_task(self):
"""The /result endpoint calls update_pipeline_pull_services_seen.delay(), not the DB directly."""
from unittest.mock import MagicMock, patch

user = User.objects.create_user(email="hbresult@example.com", is_superuser=True, is_active=True)
self.client.force_authenticate(user=user)

result_data = {
"results": [
{
"reply_subject": "test.reply.hb",
"result": {
"pipeline": "heartbeat-pipeline",
"algorithms": {},
"total_time": 0.1,
"source_images": [],
"detections": [],
"errors": None,
},
}
]
}

mock_async_result = MagicMock()
mock_async_result.id = "hb-task-id"
with (
patch("ami.jobs.views.process_nats_pipeline_result.delay", return_value=mock_async_result),
patch("ami.jobs.views.update_pipeline_pull_services_seen.delay") as mock_delay,
):
result_url = reverse_with_params(
"api:job-result", args=[self.job.pk], params={"project_id": self.project.pk}
)
resp = self.client.post(result_url, result_data, format="json")

self.assertEqual(resp.status_code, 200)
mock_delay.assert_called_once_with(self.job.pk)

def test_tasks_endpoint_tolerates_heartbeat_dispatch_failure(self):
"""Heartbeat enqueue errors should not fail the /tasks response."""
from unittest.mock import patch

from kombu.exceptions import OperationalError

job = self.job
job.status = JobState.STARTED
job.save(update_fields=["status"])

image = SourceImage.objects.create(
path="hb_tasks_broker.jpg",
public_base_url="http://example.com",
project=self.project,
)
queue_images_to_nats(job, [image])

user = User.objects.create_user(email="hbbroker@example.com", is_superuser=True, is_active=True)
self.client.force_authenticate(user=user)

with patch(
"ami.jobs.views.update_pipeline_pull_services_seen.delay",
side_effect=OperationalError("broker unavailable"),
):
tasks_url = reverse_with_params("api:job-tasks", args=[job.pk], params={"project_id": self.project.pk})
resp = self.client.post(tasks_url, {"batch_size": 1}, format="json")

self.assertEqual(resp.status_code, 200)
self.assertEqual(len(resp.json()["tasks"]), 1)

def test_view_gate_suppresses_redundant_dispatches(self):
"""Rapid repeated calls to _mark_pipeline_pull_services_seen should only enqueue once per window."""
from unittest.mock import patch

from ami.jobs.views import _mark_pipeline_pull_services_seen

with patch("ami.jobs.views.update_pipeline_pull_services_seen.delay") as mock_delay:
for _ in range(5):
_mark_pipeline_pull_services_seen(self.job)

self.assertEqual(mock_delay.call_count, 1)
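Because a Celery task body is also a plain callable, the DB write itself can be asserted without patching `.delay()` or running a broker. A possible additional method on this class, an editor's sketch (not in this PR) that assumes only the fields the task writes:

    def test_task_marks_service_live_directly(self):
        """Editor's sketch: invoke the task body synchronously, no broker needed."""
        from ami.jobs.tasks import update_pipeline_pull_services_seen

        # Calling the task object runs it inline in-process.
        update_pipeline_pull_services_seen(self.job.pk)

        self.service.refresh_from_db()
        self.assertTrue(self.service.last_seen_live)
        self.assertIsNotNone(self.service.last_seen)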
42 changes: 24 additions & 18 deletions ami/jobs/views.py
@@ -4,6 +4,7 @@
import kombu.exceptions
import nats.errors
from asgiref.sync import async_to_sync
+from django.core.cache import cache
from django.db.models import Q
from django.db.models.query import QuerySet
from django.forms import IntegerField
@@ -24,7 +25,7 @@
MLJobTasksRequestSerializer,
MLJobTasksResponseSerializer,
)
-from ami.jobs.tasks import process_nats_pipeline_result
+from ami.jobs.tasks import HEARTBEAT_THROTTLE_SECONDS, process_nats_pipeline_result, update_pipeline_pull_services_seen
from ami.main.api.schemas import project_id_doc_param
from ami.main.api.views import DefaultViewSet
from ami.utils.fields import url_boolean_param
@@ -52,26 +53,31 @@ def _actor_log_context(request) -> tuple[str, str | None]:

def _mark_pipeline_pull_services_seen(job: "Job") -> None:
"""
-    Record a heartbeat for async (pull-mode) processing services linked to the job's pipeline.
-
-    Called on every task-fetch and result-submit request so that the worker's polling activity
-    keeps last_seen/last_seen_live current. The periodic check_processing_services_online task
-    will mark services offline if this heartbeat stops arriving within PROCESSING_SERVICE_LAST_SEEN_MAX.
-
-    IMPORTANT: This marks ALL async services on the pipeline within this project as live, not just
-    the specific service that made the request. If multiple async services share the same pipeline
-    within a project, a single worker polling will keep all of them appearing online.
-    Once application-token auth is available (PR #1117), this should be scoped to the individual
-    calling service instead.
+    Enqueue a fire-and-forget heartbeat for async (pull-mode) processing services
+    linked to the job's pipeline.
+
+    A Redis cache gate skips the dispatch when a heartbeat for the same
+    (pipeline, project) has already fired within HEARTBEAT_THROTTLE_SECONDS,
+    so under concurrent polling we avoid broker + task churn. The Celery task
+    keeps the DB write off the HTTP request path.
+
+    Cache key scope: currently `heartbeat:pipeline:<pipeline_id>:project:<project_id>`
+    because we cannot yet identify the specific calling service. Once
+    application-token auth lands (PR #1117), the key should become
+    `heartbeat:service:<service_id>` so each service gets its own throttle
+    window and one service's poll does not suppress another's heartbeat.
    """
-    import datetime
-
    if not job.pipeline_id:
        return
-    job.pipeline.processing_services.async_services().filter(projects=job.project_id).update(
-        last_seen=datetime.datetime.now(),
-        last_seen_live=True,
-    )
+    cache_key = f"heartbeat:pipeline:{job.pipeline_id}:project:{job.project_id}"
+    if not cache.add(cache_key, 1, timeout=HEARTBEAT_THROTTLE_SECONDS):
+        return
+    try:
+        update_pipeline_pull_services_seen.delay(job.pk)
+    except (kombu.exceptions.KombuError, ConnectionError, OSError) as exc:
+        msg = f"Failed to enqueue non-critical pipeline heartbeat for job {job.pk}: {exc}"
+        logger.warning(msg)
+        job.logger.warning(msg)
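Note that `cache.add` is an atomic set-if-absent, so concurrent requests race safely on the gate: exactly one caller wins each throttle window. Once the calling service can be identified via application-token auth (PR #1117), the per-service key described in the docstring would reshape the gate roughly as in this editor's sketch; `_mark_processing_service_seen` and `update_service_seen` are hypothetical names, not part of this diff:

def _mark_processing_service_seen(service: "ProcessingService") -> None:
    # Hypothetical post-token-auth shape: `service` would come from the
    # authenticated application token, so each service gets its own
    # throttle window and cannot suppress another service's heartbeat.
    cache_key = f"heartbeat:service:{service.pk}"
    if not cache.add(cache_key, 1, timeout=HEARTBEAT_THROTTLE_SECONDS):
        return
    try:
        update_service_seen.delay(service.pk)  # hypothetical per-service task
    except (kombu.exceptions.KombuError, ConnectionError, OSError) as exc:
        logger.warning(f"Failed to enqueue heartbeat for service {service.pk}: {exc}")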


class JobFilterSet(filters.FilterSet):