Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 67 additions & 1 deletion ami/jobs/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,65 @@ def _get_current_counts_from_job_progress(job, stage: str) -> tuple[int, int, in
return 0, 0, 0


def _format_elapsed(seconds: float) -> str:
"""Render a duration as `Hh Mm Ss` (hours omitted when zero)."""
total = max(0, int(seconds))
h, rem = divmod(total, 3600)
m, s = divmod(rem, 60)
if h > 0:
return f"{h}h {m:02d}m {s:02d}s"
return f"{m}m {s:02d}s"


def _log_job_throughput(job, stage: str) -> None:
    """
    Write a single throughput/ETA line to the per-job log so operators can tell
    stalled-vs-slow apart from healthy-but-throttled jobs at a glance.

    Deliberately a plain average over total elapsed time — not a rolling-window
    estimate or forecast. That is accurate enough to spot a stall, cheap to
    compute, and easy to read off one log line.
    """
    # Only the stages that advance processing counts are worth reporting on,
    # and we need a start timestamp to compute elapsed time at all.
    if stage not in ("process", "results") or not job.started_at:
        return

    elapsed_seconds = (datetime.datetime.now() - job.started_at).total_seconds()
    elapsed_minutes = elapsed_seconds / 60.0
    if elapsed_minutes < 0.05:
        # Ratio over <3s of elapsed time is noise, not signal.
        return

    # The process stage holds the authoritative processed/remaining counts
    # (results stage only tracks detection/classification/capture counts).
    try:
        process_stage = job.progress.get_stage("process")
    except (ValueError, AttributeError):
        return

    counts = {"processed": 0, "remaining": 0}
    for param in getattr(process_stage, "params", []) or []:
        if param.key in counts:
            counts[param.key] = param.value or 0
    processed = counts["processed"]
    total = processed + counts["remaining"]

    if processed == 0:
        rate_str = "rate=0.0 imgs/min, ETA=unknown"
    else:
        rate = processed / elapsed_minutes
        remaining_imgs = max(0, total - processed)
        eta_seconds = (remaining_imgs / rate) * 60.0 if rate > 0 else 0.0
        rate_str = f"rate={rate:.1f} imgs/min, ETA={_format_elapsed(eta_seconds)}"

    job.logger.info(
        f"Job {job.pk} throughput: elapsed={_format_elapsed(elapsed_seconds)}, "
        f"processed={processed}/{total}, {rate_str}"
    )


def _update_job_progress(
job_id: int, stage: str, progress_percentage: float, complete_state: "JobState", **state_params
) -> None:
Expand Down Expand Up @@ -412,6 +471,10 @@ def _update_job_progress(
job.finished_at = datetime.datetime.now() # Use naive datetime in local time
job.logger.info(f"Updated job {job_id} progress in stage '{stage}' to {progress_percentage*100}%")
job.save()
try:
_log_job_throughput(job, stage)
except Exception as e:
logger.warning("Throughput log failed for job %s: %s", job_id, e)

# Clean up async resources for completed jobs that use NATS/Redis
if job.progress.is_complete():
Expand Down Expand Up @@ -714,7 +777,10 @@ def update_job_status(sender, task_id, task, state: str, retval=None, **kwargs):
# SUCCESS should only be set when all stages are actually complete
# This prevents premature SUCCESS when async workers are still processing
if state == JobState.SUCCESS and not job.progress.is_complete():
job.logger.info(
# DEBUG — fires on every async_api task_postrun (Celery task ends when
# images are queued; async workers drive the actual stages afterward).
# Always true under normal operation, so not informative at INFO.
job.logger.debug(
f"Job {job.pk} task completed but stages not finished - " "deferring SUCCESS status to progress handler"
)
return
Expand Down
271 changes: 271 additions & 0 deletions ami/jobs/tests/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,277 @@ def test_result_endpoint_validation(self):
resp = self.client.post(result_url, bare_list, format="json")
self.assertEqual(resp.status_code, 400)

def test_tasks_endpoint_logs_fetch_to_job_logger(self):
    """Successful task-fetch lands a 'Tasks fetched' line on the per-job logger."""
    pipeline = self._create_pipeline()
    job = self._create_ml_job("Job for fetch-logging test", pipeline)
    job.dispatch_mode = JobDispatchMode.ASYNC_API
    job.status = JobState.STARTED
    job.save(update_fields=["dispatch_mode", "status"])
    images = []
    for idx in range(3):
        images.append(
            SourceImage.objects.create(
                path=f"fetchlog_{idx}.jpg",
                public_base_url="http://example.com",
                project=self.project,
            )
        )
    queue_images_to_nats(job, images)

    self.client.force_authenticate(user=self.user)
    tasks_url = reverse_with_params("api:job-tasks", args=[job.pk], params={"project_id": self.project.pk})
    response = self.client.post(tasks_url, {"batch_size": 2}, format="json")
    self.assertEqual(response.status_code, 200)

    job.refresh_from_db()
    log_text = "\n".join(job.logs.stdout)
    # The fetch line must name the event, echo the request, and attribute the caller.
    for expected in ("Tasks fetched", "requested=2", "delivered=", self.user.email):
        self.assertIn(expected, log_text)

def test_tasks_endpoint_logs_early_exit_for_terminal_job(self):
    """Polling a terminal-status job produces an empty response and a 'non-active job' log line."""
    pipeline = self._create_pipeline()
    job = self._create_ml_job("Job for early-exit log test", pipeline)
    job.dispatch_mode = JobDispatchMode.ASYNC_API
    # A job already in SUCCESS must short-circuit the tasks endpoint.
    job.status = JobState.SUCCESS
    job.save(update_fields=["dispatch_mode", "status"])

    self.client.force_authenticate(user=self.user)
    tasks_url = reverse_with_params("api:job-tasks", args=[job.pk], params={"project_id": self.project.pk})
    response = self.client.post(tasks_url, {"batch_size": 5}, format="json")
    self.assertEqual(response.status_code, 200)
    self.assertEqual(response.json(), {"tasks": []})

    job.refresh_from_db()
    log_text = "\n".join(job.logs.stdout)
    self.assertIn("non-active job", log_text)
    self.assertIn(f"status={JobState.SUCCESS}", log_text)

def test_result_endpoint_mirrors_queued_log_to_job_logger(self):
    """The result endpoint mirrors its 'Queued pipeline result' line to the per-job logger."""
    from unittest.mock import MagicMock, patch

    pipeline = self._create_pipeline()
    job = self._create_ml_job("Job for result-logging test", pipeline)

    self.client.force_authenticate(user=self.user)
    result_url = reverse_with_params("api:job-result", args=[job.pk], params={"project_id": self.project.pk})

    pipeline_result = {
        "pipeline": "test-pipeline",
        "algorithms": {},
        "total_time": 0.1,
        "source_images": [],
        "detections": [],
        "errors": None,
    }
    payload = {"results": [{"reply_subject": "test.reply.logged", "result": pipeline_result}]}

    # Keep the Celery task from actually running; the log line is emitted
    # by the view before delegating to Celery.
    fake_async_result = MagicMock()
    fake_async_result.id = "mirrored-task-id"
    with patch("ami.jobs.views.process_nats_pipeline_result.delay", return_value=fake_async_result):
        response = self.client.post(result_url, payload, format="json")
        self.assertEqual(response.status_code, 200)

    job.refresh_from_db()
    log_text = "\n".join(job.logs.stdout)
    # Line must name the event, carry the task id, the reply subject, and the caller.
    for expected in ("Queued pipeline result", "mirrored-task-id", "test.reply.logged", self.user.email):
        self.assertIn(expected, log_text)

def test_tasks_fetch_log_uses_token_fingerprint_not_full_token(self):
    """
    Fix 1: token written to per-job logs is truncated to 8 chars + ellipsis,
    never the full 40-char DRF bearer secret.
    """
    from rest_framework.authtoken.models import Token

    pipeline = self._create_pipeline()
    job = self._create_ml_job("Job for token-fingerprint test", pipeline)
    job.dispatch_mode = JobDispatchMode.ASYNC_API
    job.status = JobState.STARTED
    job.save(update_fields=["dispatch_mode", "status"])
    images = []
    for idx in range(2):
        images.append(
            SourceImage.objects.create(
                path=f"tokentest_{idx}.jpg",
                public_base_url="http://example.com",
                project=self.project,
            )
        )
    queue_images_to_nats(job, images)

    token, _ = Token.objects.get_or_create(user=self.user)
    # Authenticate with the actual token object so request.auth.pk is set
    self.client.force_authenticate(user=self.user, token=token)

    tasks_url = reverse_with_params("api:job-tasks", args=[job.pk], params={"project_id": self.project.pk})
    response = self.client.post(tasks_url, {"batch_size": 2}, format="json")
    self.assertEqual(response.status_code, 200)

    job.refresh_from_db()
    log_text = "\n".join(job.logs.stdout)
    # Full token key must NOT appear anywhere in logs
    self.assertNotIn(token.key, log_text)
    # Fingerprint (first 8 chars + ellipsis) MUST appear
    self.assertIn(f"{token.key[:8]}…", log_text)

def test_tasks_fetch_zero_delivered_does_not_log_to_stdout(self):
    """
    Fix 2: when delivered==0, the log line is emitted at DEBUG and must not
    land in job.logs.stdout (JobLogHandler only captures INFO and above).
    """
    pipeline = self._create_pipeline()
    job = self._create_ml_job("Job for zero-delivered test", pipeline)
    job.dispatch_mode = JobDispatchMode.ASYNC_API
    job.status = JobState.STARTED
    job.save(update_fields=["dispatch_mode", "status"])
    # Do NOT queue any images — NATS will return 0 tasks.

    self.client.force_authenticate(user=self.user)
    tasks_url = reverse_with_params("api:job-tasks", args=[job.pk], params={"project_id": self.project.pk})
    response = self.client.post(tasks_url, {"batch_size": 5}, format="json")
    self.assertEqual(response.status_code, 200)
    self.assertEqual(len(response.json()["tasks"]), 0)

    # No Tasks fetched line should appear in stdout for a zero-delivery poll
    job.refresh_from_db()
    log_text = "\n".join(job.logs.stdout)
    self.assertNotIn("Tasks fetched", log_text)

def test_tasks_fetch_nonzero_delivered_logs_to_stdout(self):
    """
    Fix 2: when delivered>0, the log line is emitted at INFO and lands in
    job.logs.stdout with the correct delivered count.
    """
    pipeline = self._create_pipeline()
    job = self._create_ml_job("Job for nonzero-delivered test", pipeline)
    job.dispatch_mode = JobDispatchMode.ASYNC_API
    job.status = JobState.STARTED
    job.save(update_fields=["dispatch_mode", "status"])
    images = []
    for idx in range(3):
        images.append(
            SourceImage.objects.create(
                path=f"nonzero_{idx}.jpg",
                public_base_url="http://example.com",
                project=self.project,
            )
        )
    queue_images_to_nats(job, images)

    self.client.force_authenticate(user=self.user)
    tasks_url = reverse_with_params("api:job-tasks", args=[job.pk], params={"project_id": self.project.pk})
    response = self.client.post(tasks_url, {"batch_size": 3}, format="json")
    self.assertEqual(response.status_code, 200)
    self.assertEqual(len(response.json()["tasks"]), 3)

    job.refresh_from_db()
    log_text = "\n".join(job.logs.stdout)
    self.assertIn("Tasks fetched", log_text)
    self.assertIn("delivered=3", log_text)


class TestJobThroughputLogging(TestCase):
    """Unit tests for _log_job_throughput (Task 3)."""

    def setUp(self):
        self.project = Project.objects.create(name="Throughput Test Project")
        self.pipeline = Pipeline.objects.create(name="Throughput Pipeline", slug="throughput-pipeline")
        self.pipeline.projects.add(self.project)
        self.job = Job.objects.create(
            job_type_key=MLJob.key,
            project=self.project,
            name="Throughput job",
            pipeline=self.pipeline,
        )

    def _seed_process_stage(self, processed: int, remaining: int) -> None:
        # Seed the process stage exactly as the worker would: counts as stage
        # params and progress as the processed fraction (guarding divide-by-zero).
        total = max(1, processed + remaining)
        self.job.progress.add_stage("process")
        self.job.progress.update_stage(
            "process",
            progress=processed / total,
            status=JobState.STARTED,
            processed=processed,
            remaining=remaining,
            failed=0,
        )
        self.job.save()

    def _start_minutes_ago(self, minutes: int) -> None:
        # Backdate started_at so elapsed time (and therefore rate) is known.
        import datetime

        self.job.started_at = datetime.datetime.now() - datetime.timedelta(minutes=minutes)
        self.job.save(update_fields=["started_at"])

    def _stdout_text(self) -> str:
        self.job.refresh_from_db()
        return "\n".join(self.job.logs.stdout)

    def test_throughput_line_is_well_formed(self):
        from ami.jobs.tasks import _log_job_throughput

        self._seed_process_stage(processed=10, remaining=90)
        self._start_minutes_ago(5)

        _log_job_throughput(self.job, "process")

        log_text = self._stdout_text()
        self.assertIn("throughput", log_text)
        self.assertIn("processed=10/100", log_text)
        self.assertIn("rate=2.0 imgs/min", log_text)
        # ETA for 90 remaining at 2.0 imgs/min = 45 minutes
        self.assertIn("ETA=45m", log_text)

    def test_throughput_skipped_when_started_at_is_none(self):
        from ami.jobs.tasks import _log_job_throughput

        self._seed_process_stage(processed=5, remaining=5)
        self.assertIsNone(self.job.started_at)

        _log_job_throughput(self.job, "process")

        self.assertNotIn("throughput", self._stdout_text())

    def test_throughput_skipped_for_non_processing_stage(self):
        from ami.jobs.tasks import _log_job_throughput

        self._seed_process_stage(processed=10, remaining=90)
        self._start_minutes_ago(5)

        _log_job_throughput(self.job, "delay")

        self.assertNotIn("throughput", self._stdout_text())

    def test_throughput_with_zero_processed_reports_unknown_eta(self):
        from ami.jobs.tasks import _log_job_throughput

        self._seed_process_stage(processed=0, remaining=50)
        self._start_minutes_ago(5)

        _log_job_throughput(self.job, "process")

        log_text = self._stdout_text()
        self.assertIn("processed=0/50", log_text)
        self.assertIn("rate=0.0", log_text)
        self.assertIn("ETA=unknown", log_text)


class TestJobDispatchModeFiltering(APITestCase):
"""Test job filtering by dispatch_mode."""
Expand Down
Loading
Loading