Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion ami/jobs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pydantic
from celery import uuid
from celery.result import AsyncResult
from django.conf import settings
from django.db import models, transaction
from django.utils.text import slugify
from django_pydantic_field import SchemaField
Expand Down Expand Up @@ -333,9 +334,21 @@ def __init__(self, job: "Job", *args, **kwargs):
super().__init__(*args, **kwargs)

def emit(self, record: logging.LogRecord):
# Log to the current app logger
# Log to the current app logger (container stdout).
logger.log(record.levelno, self.format(record))

# Gated by ``JOB_LOG_PERSIST_ENABLED`` (default True). Persisting every
# log line to ``jobs_job.logs`` becomes a row-lock contention point
# under concurrent async_api load — each call triggers
# ``UPDATE jobs_job SET logs = ...`` on the shared job row, and inside
# ``ATOMIC_REQUESTS`` a single batched ``/result`` POST stacks N such
# UPDATEs in one tx, blocking every ML worker on the same row for the
# duration of the request. Deployments hitting that pattern can set the
# flag to False to short-circuit here until PR #1259 lands an
# append-only ``JobLog`` child table. See issue #1256.
if not getattr(settings, "JOB_LOG_PERSIST_ENABLED", True):
return

# Write to the logs field on the job instance.
# Refresh from DB first to reduce the window for concurrent overwrites — each
# worker holds its own stale in-memory copy of `logs`, so without a refresh the
Expand Down
81 changes: 79 additions & 2 deletions ami/jobs/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import TYPE_CHECKING

from asgiref.sync import async_to_sync, sync_to_async
from cachalot.api import cachalot_disabled
from celery.signals import task_failure, task_postrun, task_prerun
from django.db import transaction
from redis.exceptions import RedisError
Expand Down Expand Up @@ -76,6 +77,56 @@ def update_pipeline_pull_services_seen(job_id: int) -> None:
)


@celery_app.task(
    soft_time_limit=10,
    time_limit=15,
    ignore_result=True,
)
def update_async_services_seen_for_pipelines(pipeline_slugs: list[str]) -> None:
    """
    Heartbeat for idle worker polls on
    ``GET /api/v2/jobs/?pipeline__slug__in=...&ids_only=1``.

    The ADC worker sends pipeline slugs but no project_id (one worker may serve
    pipelines across many projects), so scope the heartbeat by the pipelines it
    asked about. Marks every async ProcessingService linked to any of those
    pipelines as seen.
    """
    from ami.ml.models import ProcessingService  # avoid circular import

    # No slugs supplied → nothing to mark; bail out rather than issue an
    # unfiltered bulk UPDATE.
    if not pipeline_slugs:
        return

    # Single bulk UPDATE across all matching services. ``distinct()`` guards
    # against duplicate join rows when a service is linked to several of the
    # given pipelines. NOTE(review): naive ``datetime.now()`` (local time)
    # appears to be the convention elsewhere in this module — confirm that
    # ``last_seen`` consumers expect naive local timestamps.
    ProcessingService.objects.async_services().filter(
        pipelines__slug__in=pipeline_slugs,
    ).distinct().update(
        last_seen=datetime.datetime.now(),
        last_seen_live=True,
    )


@celery_app.task(
    soft_time_limit=10,
    time_limit=15,
    ignore_result=True,
)
def update_async_services_seen_for_project(project_id: int) -> None:
    """
    Heartbeat for idle worker polls on ``GET /api/v2/jobs/?ids_only=1``.

    Unlike ``update_pipeline_pull_services_seen`` — which is pipeline-scoped and
    only fires when a worker hits /tasks/ or /result/ for an active job — this
    marks every async processing service attached to the polling project as
    seen. The list endpoint has no pipeline context, so scope is the project.
    """
    from ami.ml.models import ProcessingService  # avoid circular import

    # Bulk UPDATE of every async service on the project in one query; the
    # naive local-time timestamp matches the convention used by the
    # pipeline-scoped heartbeat task above.
    ProcessingService.objects.async_services().filter(projects=project_id).update(
        last_seen=datetime.datetime.now(),
        last_seen_live=True,
    )


@celery_app.task(bind=True, soft_time_limit=default_soft_time_limit, time_limit=default_time_limit)
def run_job(self, job_id: int) -> None:
from ami.jobs.models import Job
Expand Down Expand Up @@ -158,6 +209,15 @@ def _log_worker_availability(job) -> None:
soft_time_limit=300, # 5 minutes
time_limit=360, # 6 minutes
)
# Disable cachalot cache invalidation for this task. Each call writes
# Detection/Classification rows and UPDATEs jobs_job; under concurrent
# async_api load, cachalot's post-write invalidation added ~2.5s/task
# (measured on demo, issue #1256 Path 4). This is a pure write path —
# nothing inside benefits from the query cache — so skipping invalidation
# is strictly a throughput win. Celery task decorator stack order matters:
# @celery_app.task wraps the cachalot-wrapped function, so Celery sees the
# cachalot context manager enter/exit on every task execution.
@cachalot_disabled()
def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_subject: str) -> None:
Comment thread
mihow marked this conversation as resolved.
"""
Process a single pipeline result asynchronously.
Expand Down Expand Up @@ -489,8 +549,19 @@ def _update_job_progress(
) -> None:
from ami.jobs.models import Job, JobState # avoid circular import

# NOTE: Previously this used `select_for_update()` inside `transaction.atomic()`
# to serialize concurrent progress updates for the same job. Under concurrent
# async_api result processing that serialization became a bottleneck: every
# ML result task queued a contending exclusive lock on the `jobs_job` row,
# stacking behind gunicorn view threads also holding the row under
# ATOMIC_REQUESTS. The `max()` guard below still prevents progress regression
# between concurrent workers; the trade-off is that accumulated counts
# (detections/classifications/captures) can drift by one batch under race —
# cosmetic only, since the underlying `Detection`/`Classification` rows are
# written authoritatively by `save_results` before this function runs.
# See issue #1256 and PR #1261.
with transaction.atomic():
job = Job.objects.select_for_update().get(pk=job_id)
job = Job.objects.get(pk=job_id)
Comment thread
mihow marked this conversation as resolved.
Comment thread
coderabbitai[bot] marked this conversation as resolved.

# For results stage, accumulate detections/classifications/captures counts
if stage == "results":
Expand Down Expand Up @@ -557,7 +628,13 @@ def _update_job_progress(
job.progress.summary.status = complete_state
job.finished_at = datetime.datetime.now() # Use naive datetime in local time
job.logger.info(f"Updated job {job_id} progress in stage '{stage}' to {progress_percentage*100}%")
job.save()
# Narrow the write to the fields we actually mutated. Without this, a full
# save() would also overwrite `updated_at`, `logs`, and any other field on
Comment thread
mihow marked this conversation as resolved.
Outdated
# the instance fetched at the top of this block — so a concurrent worker's
# append to `progress.errors` (via `_reconcile_lost_images`) or log line
# (via JobLogHandler) could be clobbered by a stale read-modify-write.
# See PR #1261 review feedback.
job.save(update_fields=["progress", "status", "finished_at", "updated_at"])
try:
_log_job_throughput(job, stage)
except Exception as e:
Expand Down
159 changes: 158 additions & 1 deletion ami/jobs/tests/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,11 @@ def test_list_jobs_with_ids_only(self):
self._create_job("Test job 3", start_now=False)

self.client.force_authenticate(user=self.user)
jobs_list_url = reverse_with_params("api:job-list", params={"project_id": self.project.pk, "ids_only": True})
# Pass an explicit limit to override the pop()-style default (see test_list_jobs_ids_only_pops_one below).
jobs_list_url = reverse_with_params(
"api:job-list",
params={"project_id": self.project.pk, "ids_only": True, "limit": 10},
)
resp = self.client.get(jobs_list_url)

self.assertEqual(resp.status_code, 200)
Expand All @@ -388,6 +392,21 @@ def test_list_jobs_with_ids_only(self):
# Verify we don't get the full results structure
self.assertNotIn("details", data["results"][0])

def test_list_jobs_ids_only_pops_one(self):
"""`?ids_only=1` without an explicit limit returns one job (pop()-style handoff)."""
self._create_job("Test job 2", start_now=False)
self._create_job("Test job 3", start_now=False)

self.client.force_authenticate(user=self.user)
jobs_list_url = reverse_with_params("api:job-list", params={"project_id": self.project.pk, "ids_only": True})
resp = self.client.get(jobs_list_url)

self.assertEqual(resp.status_code, 200)
data = resp.json()
self.assertEqual(data["count"], 3)
self.assertEqual(len(data["results"]), 1)
self.assertIsInstance(data["results"][0]["id"], int)

def test_list_jobs_with_incomplete_only(self):
"""Test the incomplete_only parameter filters jobs correctly."""
# Create jobs via API
Expand Down Expand Up @@ -1157,3 +1176,141 @@ def test_view_gate_suppresses_redundant_dispatches(self):
_mark_pipeline_pull_services_seen(self.job)

self.assertEqual(mock_delay.call_count, 1)


class TestListEndpointHeartbeat(APITestCase):
    """
    Unit tests for _mark_async_services_seen_for_project and the list endpoint's
    heartbeat dispatch on ``?ids_only=1`` polls.
    """

    def setUp(self):
        from django.core.cache import cache

        # Reset the cache so each test starts outside any throttle window.
        # NOTE(review): assumes the view-side dispatch gate is cache-backed —
        # confirm against ami.jobs.views.
        cache.clear()

        self.project = Project.objects.create(name="List Heartbeat Project")
        # endpoint_url=None is what makes this an "async" service here;
        # presumably ProcessingService.objects.async_services() filters on it —
        # verify against the manager definition.
        self.service = ProcessingService.objects.create(
            name="List Heartbeat Worker",
            endpoint_url=None,
        )
        self.service.projects.add(self.project)

        self.user = User.objects.create_user(email="list-heartbeat@example.com", is_superuser=True, is_active=True)
        self.client.force_authenticate(user=self.user)

    def test_list_with_ids_only_dispatches_heartbeat(self):
        """An ``?ids_only=1`` list poll enqueues one project-scoped heartbeat task."""
        from unittest.mock import patch

        list_url = reverse_with_params("api:job-list", params={"project_id": self.project.pk, "ids_only": True})
        with patch("ami.jobs.views.update_async_services_seen_for_project.delay") as mock_delay:
            resp = self.client.get(list_url)

        self.assertEqual(resp.status_code, 200)
        mock_delay.assert_called_once_with(self.project.pk)

    def test_list_without_ids_only_does_not_dispatch_heartbeat(self):
        """A plain list request (no ``ids_only``) must not trigger the heartbeat."""
        from unittest.mock import patch

        list_url = reverse_with_params("api:job-list", params={"project_id": self.project.pk})
        with patch("ami.jobs.views.update_async_services_seen_for_project.delay") as mock_delay:
            resp = self.client.get(list_url)

        self.assertEqual(resp.status_code, 200)
        mock_delay.assert_not_called()

    def test_list_heartbeat_tolerates_dispatch_failure(self):
        """Broker unavailability on heartbeat enqueue must not break the list response."""
        from unittest.mock import patch

        from kombu.exceptions import OperationalError

        list_url = reverse_with_params("api:job-list", params={"project_id": self.project.pk, "ids_only": True})
        # Simulate a down broker: .delay() raises instead of enqueueing.
        with patch(
            "ami.jobs.views.update_async_services_seen_for_project.delay",
            side_effect=OperationalError("broker unavailable"),
        ):
            resp = self.client.get(list_url)

        # The poll itself still succeeds — the heartbeat is best-effort.
        self.assertEqual(resp.status_code, 200)

    def test_view_gate_suppresses_redundant_list_dispatches(self):
        """Rapid repeated list polls should dispatch at most once per throttle window."""
        from unittest.mock import patch

        from ami.jobs.views import _mark_async_services_seen_for_project

        with patch("ami.jobs.views.update_async_services_seen_for_project.delay") as mock_delay:
            for _ in range(5):
                _mark_async_services_seen_for_project(self.project.pk)

        # Five back-to-back calls inside one window collapse to one dispatch.
        self.assertEqual(mock_delay.call_count, 1)

    def test_list_with_pipeline_slugs_no_project_dispatches_heartbeat(self):
        """Real ADC worker shape: ?ids_only=1&pipeline__slug__in=... with no project_id."""
        from unittest.mock import patch

        pipeline = Pipeline.objects.create(name="Heartbeat Pipeline", slug="heartbeat-pipeline")
        self.service.pipelines.add(pipeline)

        list_url = reverse_with_params(
            "api:job-list",
            params={"ids_only": True, "pipeline__slug__in": "heartbeat-pipeline"},
        )
        with patch("ami.jobs.views.update_async_services_seen_for_pipelines.delay") as mock_delay:
            resp = self.client.get(list_url)

        self.assertEqual(resp.status_code, 200)
        # With no project_id, the view falls back to the pipeline-scoped task,
        # passing along the slugs that were polled.
        mock_delay.assert_called_once_with(["heartbeat-pipeline"])

    def test_task_updates_services_via_pipeline_slug(self):
        """The pipeline-slug celery task marks matching async services live."""
        import datetime

        from ami.jobs.tasks import update_async_services_seen_for_pipelines

        pipeline = Pipeline.objects.create(name="Slug Pipeline", slug="slug-pipeline")
        self.service.pipelines.add(pipeline)
        # A second async service with no pipeline link — must remain untouched.
        unrelated = ProcessingService.objects.create(name="Unrelated Async", endpoint_url=None)
        unrelated_last_seen_before = unrelated.last_seen

        # One second of slack so the naive-datetime comparison below cannot
        # flake on clock granularity.
        before = datetime.datetime.now() - datetime.timedelta(seconds=1)
        update_async_services_seen_for_pipelines(["slug-pipeline"])

        self.service.refresh_from_db()
        unrelated.refresh_from_db()

        self.assertTrue(self.service.last_seen_live)
        self.assertIsNotNone(self.service.last_seen)
        self.assertGreaterEqual(self.service.last_seen, before)
        self.assertEqual(unrelated.last_seen, unrelated_last_seen_before)

    def test_task_updates_all_project_async_services(self):
        """The celery task marks every async service on the project live."""
        import datetime

        from ami.jobs.tasks import update_async_services_seen_for_project

        other_async = ProcessingService.objects.create(name="Other Async", endpoint_url=None)
        other_async.projects.add(self.project)
        # A sync service (real endpoint URL) on the same project — the task
        # must leave it alone.
        sync_service = ProcessingService.objects.create(
            name="Sync Service", endpoint_url="http://nonexistent-host:9999"
        )
        sync_service.projects.add(self.project)
        sync_last_seen_before = ProcessingService.objects.get(pk=sync_service.pk).last_seen

        before = datetime.datetime.now() - datetime.timedelta(seconds=1)
        update_async_services_seen_for_project(self.project.pk)

        self.service.refresh_from_db()
        other_async.refresh_from_db()
        sync_service.refresh_from_db()

        self.assertTrue(self.service.last_seen_live)
        self.assertIsNotNone(self.service.last_seen)
        self.assertGreaterEqual(self.service.last_seen, before)
        self.assertTrue(other_async.last_seen_live)
        # Sync services (with endpoint URL) are not touched by this task — last_seen
        # may be set by the creation-time get_status() ping, but should be unchanged
        # after the task runs.
        self.assertEqual(sync_service.last_seen, sync_last_seen_before)
Loading
Loading