Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions ami/jobs/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,57 @@
# "nobody's listening" signal.
WORKER_AVAILABILITY_ONLINE_CUTOFF = datetime.timedelta(minutes=5)

# Minimum interval between heartbeat DB writes for a given job.
# PROCESSING_SERVICE_LAST_SEEN_MAX is 60s; writing at most once per 30s keeps
# last_seen current without hammering the same rows on every concurrent request.
# Compared against ProcessingService.last_seen in update_pipeline_pull_services_seen.
HEARTBEAT_THROTTLE_SECONDS = 30
Comment thread
mihow marked this conversation as resolved.
Outdated


@celery_app.task(
    soft_time_limit=10,
    time_limit=15,
    # No retries — a missed heartbeat is benign; retrying adds load for no gain.
    # Fire-and-forget: the return value is never read, so skip the result backend
    # write entirely instead of storing a None result row per heartbeat.
    ignore_result=True,
)
def update_pipeline_pull_services_seen(job_id: int) -> None:
    """
    Fire-and-forget heartbeat task: record last_seen/last_seen_live for async
    (pull-mode) processing services linked to a job's pipeline.

    Called via .delay() from the tasks and result view endpoints so the HTTP
    request is never blocked on this DB write.

    Throttle: skips the UPDATE if all matching services were seen within
    HEARTBEAT_THROTTLE_SECONDS, cutting write rate under concurrent requests by
    orders of magnitude while keeping last_seen fresh relative to the 60s
    PROCESSING_SERVICE_LAST_SEEN_MAX threshold.

    Scope: marks ALL async services on the pipeline within this project as live,
    not just the specific service that made the request. Once application-token
    auth is available (PR #1117), this should be scoped to the individual
    calling service instead.

    :param job_id: primary key of the Job whose pipeline's async services
        should be marked as seen. A missing job is silently ignored.
    """
    from ami.jobs.models import Job  # avoid circular import

    try:
        job = Job.objects.select_related("pipeline").get(pk=job_id)
    except Job.DoesNotExist:
        # The job may have been deleted between dispatch and execution; benign.
        return

    if not job.pipeline_id:
        return

    # NOTE(review): naive datetime matches the existing convention in this file;
    # if the project enables USE_TZ this should become django.utils.timezone.now().
    now = datetime.datetime.now()
    throttle_cutoff = now - datetime.timedelta(seconds=HEARTBEAT_THROTTLE_SECONDS)

    services_qs = job.pipeline.processing_services.async_services().filter(projects=job.project_id)

    # Cheap read: skip the UPDATE if every matching service was seen recently.
    # bool() fully evaluates and caches the queryset, so all() reuses the same rows.
    recent_seen = services_qs.values_list("last_seen", flat=True)
    if recent_seen and all(ts is not None and ts >= throttle_cutoff for ts in recent_seen):
        return

    services_qs.update(last_seen=now, last_seen_live=True)

Comment thread
mihow marked this conversation as resolved.
Outdated

@celery_app.task(bind=True, soft_time_limit=default_soft_time_limit, time_limit=default_time_limit)
def run_job(self, job_id: int) -> None:
Expand Down
164 changes: 164 additions & 0 deletions ami/jobs/tests/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ami.jobs.models import Job, JobDispatchMode, JobProgress, JobState, MLJob, SourceImageCollectionPopulateJob
from ami.main.models import Project, SourceImage, SourceImageCollection
from ami.ml.models import Pipeline
from ami.ml.models.processing_service import ProcessingService
from ami.ml.orchestration.jobs import queue_images_to_nats
from ami.users.models import User

Expand Down Expand Up @@ -1016,3 +1017,166 @@ def test_tasks_endpoint_rejects_non_async_jobs(self):
resp = self.client.post(tasks_url, {"batch_size": 1}, format="json")
self.assertEqual(resp.status_code, 400)
self.assertIn("async_api", resp.json()[0].lower())


class TestPipelineHeartbeatTask(APITestCase):
    """
    Unit tests for update_pipeline_pull_services_seen and the view-level
    _mark_pipeline_pull_services_seen fire-and-forget dispatch.
    """

    def setUp(self):
        # One project/pipeline/collection/job fixture plus a single pull-mode
        # (async) processing service attached to both the pipeline and project.
        self.project = Project.objects.create(name="Heartbeat Test Project")
        self.pipeline = Pipeline.objects.create(name="Heartbeat Pipeline", slug="heartbeat-pipeline")
        self.pipeline.projects.add(self.project)
        self.collection = SourceImageCollection.objects.create(name="HB Collection", project=self.project)
        self.job = Job.objects.create(
            job_type_key=MLJob.key,
            project=self.project,
            name="Heartbeat Test Job",
            pipeline=self.pipeline,
            source_image_collection=self.collection,
            dispatch_mode=JobDispatchMode.ASYNC_API,
        )
        self.service = ProcessingService.objects.create(
            name="Heartbeat Worker",
            endpoint_url=None,  # a null endpoint marks the service as pull-mode / async
        )
        self.service.pipelines.add(self.pipeline)
        self.service.projects.add(self.project)

    def test_tasks_endpoint_dispatches_heartbeat_task(self):
        """The /tasks endpoint calls update_pipeline_pull_services_seen.delay(), not the DB directly."""
        from unittest.mock import patch

        self.job.status = JobState.STARTED
        self.job.save(update_fields=["status"])

        queued_images = []
        for i in range(2):
            queued_images.append(
                SourceImage.objects.create(
                    path=f"hb_tasks_{i}.jpg",
                    public_base_url="http://example.com",
                    project=self.project,
                )
            )
        queue_images_to_nats(self.job, queued_images)

        superuser = User.objects.create_user(email="hbtest@example.com", is_superuser=True, is_active=True)
        self.client.force_authenticate(user=superuser)

        tasks_url = reverse_with_params("api:job-tasks", args=[self.job.pk], params={"project_id": self.project.pk})
        with patch("ami.jobs.views.update_pipeline_pull_services_seen.delay") as mock_delay:
            response = self.client.post(tasks_url, {"batch_size": 1}, format="json")

        self.assertEqual(response.status_code, 200)
        mock_delay.assert_called_once_with(self.job.pk)

    def test_result_endpoint_dispatches_heartbeat_task(self):
        """The /result endpoint calls update_pipeline_pull_services_seen.delay(), not the DB directly."""
        from unittest.mock import MagicMock, patch

        superuser = User.objects.create_user(email="hbresult@example.com", is_superuser=True, is_active=True)
        self.client.force_authenticate(user=superuser)

        payload = {
            "results": [
                {
                    "reply_subject": "test.reply.hb",
                    "result": {
                        "pipeline": "heartbeat-pipeline",
                        "algorithms": {},
                        "total_time": 0.1,
                        "source_images": [],
                        "detections": [],
                        "errors": None,
                    },
                }
            ]
        }

        fake_async_result = MagicMock()
        fake_async_result.id = "hb-task-id"
        result_url = reverse_with_params(
            "api:job-result", args=[self.job.pk], params={"project_id": self.project.pk}
        )
        with (
            patch("ami.jobs.views.process_nats_pipeline_result.delay", return_value=fake_async_result),
            patch("ami.jobs.views.update_pipeline_pull_services_seen.delay") as mock_delay,
        ):
            response = self.client.post(result_url, payload, format="json")

        self.assertEqual(response.status_code, 200)
        mock_delay.assert_called_once_with(self.job.pk)

    def test_heartbeat_task_updates_last_seen_when_stale(self):
        """update_pipeline_pull_services_seen writes last_seen when the service is stale."""
        import datetime

        from ami.jobs.tasks import update_pipeline_pull_services_seen

        # Push last_seen well past the throttle window so the UPDATE must fire.
        stale_time = datetime.datetime.now() - datetime.timedelta(minutes=5)
        self.service.last_seen = stale_time
        self.service.last_seen_live = False
        self.service.save(update_fields=["last_seen", "last_seen_live"])

        update_pipeline_pull_services_seen(self.job.pk)

        self.service.refresh_from_db()
        self.assertTrue(self.service.last_seen_live)
        self.assertGreater(self.service.last_seen, stale_time)

    def test_heartbeat_task_skips_update_when_recent(self):
        """update_pipeline_pull_services_seen skips the UPDATE when last_seen is within the throttle window."""
        import datetime

        from ami.jobs.tasks import update_pipeline_pull_services_seen

        # A heartbeat from 5s ago sits comfortably inside the 30s throttle window.
        just_seen = datetime.datetime.now() - datetime.timedelta(seconds=5)
        self.service.last_seen = just_seen
        self.service.last_seen_live = True
        self.service.save(update_fields=["last_seen", "last_seen_live"])

        update_pipeline_pull_services_seen(self.job.pk)

        self.service.refresh_from_db()
        # A skipped UPDATE means last_seen stays (approximately) where we set it.
        self.assertAlmostEqual(
            self.service.last_seen.timestamp(),
            just_seen.timestamp(),
            delta=1.0,
        )

    def test_heartbeat_task_no_op_for_missing_job(self):
        """update_pipeline_pull_services_seen silently returns when job_id does not exist."""
        from ami.jobs.tasks import update_pipeline_pull_services_seen

        # Must not raise for a nonexistent primary key.
        update_pipeline_pull_services_seen(job_id=999999)

    def test_heartbeat_task_no_op_for_job_without_pipeline(self):
        """update_pipeline_pull_services_seen returns early when job has no pipeline."""
        import datetime

        from ami.jobs.tasks import update_pipeline_pull_services_seen

        pipelineless_job = Job.objects.create(
            job_type_key=MLJob.key,
            project=self.project,
            name="No-pipeline job",
            source_image_collection=self.collection,
            dispatch_mode=JobDispatchMode.ASYNC_API,
        )

        baseline = datetime.datetime.now() - datetime.timedelta(minutes=10)
        self.service.last_seen = baseline
        self.service.save(update_fields=["last_seen"])

        update_pipeline_pull_services_seen(pipelineless_job.pk)

        # The early return leaves the service heartbeat untouched.
        self.service.refresh_from_db()
        self.assertAlmostEqual(self.service.last_seen.timestamp(), baseline.timestamp(), delta=1.0)
32 changes: 14 additions & 18 deletions ami/jobs/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
MLJobTasksRequestSerializer,
MLJobTasksResponseSerializer,
)
from ami.jobs.tasks import process_nats_pipeline_result
from ami.jobs.tasks import process_nats_pipeline_result, update_pipeline_pull_services_seen
from ami.main.api.schemas import project_id_doc_param
from ami.main.api.views import DefaultViewSet
from ami.utils.fields import url_boolean_param
Expand Down Expand Up @@ -52,26 +52,22 @@ def _actor_log_context(request) -> tuple[str, str | None]:

def _mark_pipeline_pull_services_seen(job: "Job") -> None:
    """
    Enqueue a fire-and-forget heartbeat for async (pull-mode) processing services
    linked to the job's pipeline.

    Dispatches update_pipeline_pull_services_seen via Celery .delay() so the view
    is never blocked on the DB write. The task throttles writes to at most once per
    ~30 seconds per job, keeping last_seen current relative to the 60s
    PROCESSING_SERVICE_LAST_SEEN_MAX threshold without hammering the same rows on
    every concurrent task-fetch or result-submit request.

    Per-service scoping is not yet possible — marks ALL async services on the
    pipeline within this project as live. Once application-token auth lands
    (PR #1117) this can be scoped to the individual calling service.
    """
    update_pipeline_pull_services_seen.delay(job.pk)
Comment thread
mihow marked this conversation as resolved.
Outdated


class JobFilterSet(filters.FilterSet):
Expand Down
Loading