diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 183cd5186..629c95828 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -34,6 +34,47 @@ # "nobody's listening" signal. WORKER_AVAILABILITY_ONLINE_CUTOFF = datetime.timedelta(minutes=5) +# Minimum interval between heartbeat dispatches for a given (pipeline, project). +# The view-level Redis cache gate uses this window to skip .delay() under +# concurrent polling; the task itself does no throttling. +HEARTBEAT_THROTTLE_SECONDS = 30 + + +@celery_app.task( + soft_time_limit=10, + time_limit=15, + ignore_result=True, + # No retries — a missed heartbeat is benign; retrying adds load for no gain. +) +def update_pipeline_pull_services_seen(job_id: int) -> None: + """ + Fire-and-forget heartbeat task: record last_seen/last_seen_live for async + (pull-mode) processing services linked to a job's pipeline. + + Throttling lives in the view (Redis cache gate over HEARTBEAT_THROTTLE_SECONDS), + so this task is dispatched at most once per (pipeline, project) per window + and can just write. + + Scope: marks ALL async services on the pipeline within this project as live, + not just the specific service that made the request. Once application-token + auth is available (PR #1117), this should be scoped to the individual + calling service instead. 
+ """ + from ami.jobs.models import Job # avoid circular import + + try: + job = Job.objects.select_related("pipeline").get(pk=job_id) + except Job.DoesNotExist: + return + + if not job.pipeline_id: + return + + job.pipeline.processing_services.async_services().filter(projects=job.project_id).update( + last_seen=datetime.datetime.now(), + last_seen_live=True, + ) + @celery_app.task(bind=True, soft_time_limit=default_soft_time_limit, time_limit=default_time_limit) def run_job(self, job_id: int) -> None: diff --git a/ami/jobs/tests/test_jobs.py b/ami/jobs/tests/test_jobs.py index 847a61f7e..90d1f6baa 100644 --- a/ami/jobs/tests/test_jobs.py +++ b/ami/jobs/tests/test_jobs.py @@ -11,6 +11,7 @@ from ami.jobs.models import Job, JobDispatchMode, JobProgress, JobState, MLJob, SourceImageCollectionPopulateJob from ami.main.models import Project, SourceImage, SourceImageCollection from ami.ml.models import Pipeline +from ami.ml.models.processing_service import ProcessingService from ami.ml.orchestration.jobs import queue_images_to_nats from ami.users.models import User @@ -1016,3 +1017,143 @@ def test_tasks_endpoint_rejects_non_async_jobs(self): resp = self.client.post(tasks_url, {"batch_size": 1}, format="json") self.assertEqual(resp.status_code, 400) self.assertIn("async_api", resp.json()[0].lower()) + + +class TestPipelineHeartbeatTask(APITestCase): + """ + Unit tests for update_pipeline_pull_services_seen and the view-level + _mark_pipeline_pull_services_seen fire-and-forget dispatch. + """ + + def setUp(self): + from django.core.cache import cache + + # Cache-based gate in _mark_pipeline_pull_services_seen would otherwise + # carry over between tests and suppress the .delay() we want to assert. 
+ cache.clear() + + self.project = Project.objects.create(name="Heartbeat Test Project") + self.pipeline = Pipeline.objects.create(name="Heartbeat Pipeline", slug="heartbeat-pipeline") + self.pipeline.projects.add(self.project) + self.collection = SourceImageCollection.objects.create(name="HB Collection", project=self.project) + self.job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Heartbeat Test Job", + pipeline=self.pipeline, + source_image_collection=self.collection, + dispatch_mode=JobDispatchMode.ASYNC_API, + ) + self.service = ProcessingService.objects.create( + name="Heartbeat Worker", + endpoint_url=None, # None = pull-mode / async service + ) + self.service.pipelines.add(self.pipeline) + self.service.projects.add(self.project) + + def test_tasks_endpoint_dispatches_heartbeat_task(self): + """The /tasks endpoint calls update_pipeline_pull_services_seen.delay(), not the DB directly.""" + from unittest.mock import patch + + job = self.job + job.status = JobState.STARTED + job.save(update_fields=["status"]) + + images = [ + SourceImage.objects.create( + path=f"hb_tasks_{i}.jpg", + public_base_url="http://example.com", + project=self.project, + ) + for i in range(2) + ] + queue_images_to_nats(job, images) + + user = User.objects.create_user(email="hbtest@example.com", is_superuser=True, is_active=True) + self.client.force_authenticate(user=user) + + with patch("ami.jobs.views.update_pipeline_pull_services_seen.delay") as mock_delay: + tasks_url = reverse_with_params("api:job-tasks", args=[job.pk], params={"project_id": self.project.pk}) + resp = self.client.post(tasks_url, {"batch_size": 1}, format="json") + + self.assertEqual(resp.status_code, 200) + mock_delay.assert_called_once_with(job.pk) + + def test_result_endpoint_dispatches_heartbeat_task(self): + """The /result endpoint calls update_pipeline_pull_services_seen.delay(), not the DB directly.""" + from unittest.mock import MagicMock, patch + + user = 
User.objects.create_user(email="hbresult@example.com", is_superuser=True, is_active=True) + self.client.force_authenticate(user=user) + + result_data = { + "results": [ + { + "reply_subject": "test.reply.hb", + "result": { + "pipeline": "heartbeat-pipeline", + "algorithms": {}, + "total_time": 0.1, + "source_images": [], + "detections": [], + "errors": None, + }, + } + ] + } + + mock_async_result = MagicMock() + mock_async_result.id = "hb-task-id" + with ( + patch("ami.jobs.views.process_nats_pipeline_result.delay", return_value=mock_async_result), + patch("ami.jobs.views.update_pipeline_pull_services_seen.delay") as mock_delay, + ): + result_url = reverse_with_params( + "api:job-result", args=[self.job.pk], params={"project_id": self.project.pk} + ) + resp = self.client.post(result_url, result_data, format="json") + + self.assertEqual(resp.status_code, 200) + mock_delay.assert_called_once_with(self.job.pk) + + def test_tasks_endpoint_tolerates_heartbeat_dispatch_failure(self): + """Heartbeat enqueue errors should not fail the /tasks response.""" + from unittest.mock import patch + + from kombu.exceptions import OperationalError + + job = self.job + job.status = JobState.STARTED + job.save(update_fields=["status"]) + + image = SourceImage.objects.create( + path="hb_tasks_broker.jpg", + public_base_url="http://example.com", + project=self.project, + ) + queue_images_to_nats(job, [image]) + + user = User.objects.create_user(email="hbbroker@example.com", is_superuser=True, is_active=True) + self.client.force_authenticate(user=user) + + with patch( + "ami.jobs.views.update_pipeline_pull_services_seen.delay", + side_effect=OperationalError("broker unavailable"), + ): + tasks_url = reverse_with_params("api:job-tasks", args=[job.pk], params={"project_id": self.project.pk}) + resp = self.client.post(tasks_url, {"batch_size": 1}, format="json") + + self.assertEqual(resp.status_code, 200) + self.assertEqual(len(resp.json()["tasks"]), 1) + + def 
test_view_gate_suppresses_redundant_dispatches(self): + """Rapid repeated calls to _mark_pipeline_pull_services_seen should only enqueue once per window.""" + from unittest.mock import patch + + from ami.jobs.views import _mark_pipeline_pull_services_seen + + with patch("ami.jobs.views.update_pipeline_pull_services_seen.delay") as mock_delay: + for _ in range(5): + _mark_pipeline_pull_services_seen(self.job) + + self.assertEqual(mock_delay.call_count, 1) diff --git a/ami/jobs/views.py b/ami/jobs/views.py index 47c2461b9..ec9d64481 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -4,6 +4,7 @@ import kombu.exceptions import nats.errors from asgiref.sync import async_to_sync +from django.core.cache import cache from django.db.models import Q from django.db.models.query import QuerySet from django.forms import IntegerField @@ -24,7 +25,7 @@ MLJobTasksRequestSerializer, MLJobTasksResponseSerializer, ) -from ami.jobs.tasks import process_nats_pipeline_result +from ami.jobs.tasks import HEARTBEAT_THROTTLE_SECONDS, process_nats_pipeline_result, update_pipeline_pull_services_seen from ami.main.api.schemas import project_id_doc_param from ami.main.api.views import DefaultViewSet from ami.utils.fields import url_boolean_param @@ -52,26 +53,31 @@ def _actor_log_context(request) -> tuple[str, str | None]: def _mark_pipeline_pull_services_seen(job: "Job") -> None: """ - Record a heartbeat for async (pull-mode) processing services linked to the job's pipeline. - - Called on every task-fetch and result-submit request so that the worker's polling activity - keeps last_seen/last_seen_live current. The periodic check_processing_services_online task - will mark services offline if this heartbeat stops arriving within PROCESSING_SERVICE_LAST_SEEN_MAX. - - IMPORTANT: This marks ALL async services on the pipeline within this project as live, not just - the specific service that made the request. 
If multiple async services share the same pipeline - within a project, a single worker polling will keep all of them appearing online. - Once application-token auth is available (PR #1117), this should be scoped to the individual - calling service instead. + Enqueue a fire-and-forget heartbeat for async (pull-mode) processing services + linked to the job's pipeline. + + A Redis cache gate skips the dispatch when a heartbeat for the same + (pipeline, project) has already fired within HEARTBEAT_THROTTLE_SECONDS, + so under concurrent polling we avoid broker + task churn. The Celery task + keeps the DB write off the HTTP request path. + + Cache key scope: currently `heartbeat:pipeline:<pipeline_id>:project:<project_id>` + because we cannot yet identify the specific calling service. Once + application-token auth lands (PR #1117), the key should become + `heartbeat:service:<service_id>` so each service gets its own throttle + window and one service's poll does not suppress another's heartbeat. """ - import datetime - if not job.pipeline_id: return - job.pipeline.processing_services.async_services().filter(projects=job.project_id).update( - last_seen=datetime.datetime.now(), - last_seen_live=True, - ) + cache_key = f"heartbeat:pipeline:{job.pipeline_id}:project:{job.project_id}" + if not cache.add(cache_key, 1, timeout=HEARTBEAT_THROTTLE_SECONDS): + return + try: + update_pipeline_pull_services_seen.delay(job.pk) + except (kombu.exceptions.KombuError, ConnectionError, OSError) as exc: + msg = f"Failed to enqueue non-critical pipeline heartbeat for job {job.pk}: {exc}" + logger.warning(msg) + job.logger.warning(msg) class JobFilterSet(filters.FilterSet):