Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions ami/jobs/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,65 @@ def _get_current_counts_from_job_progress(job, stage: str) -> tuple[int, int, in
return 0, 0, 0


def _format_elapsed(seconds: float) -> str:
"""Render a duration as `Hh Mm Ss` (hours omitted when zero)."""
total = max(0, int(seconds))
h, rem = divmod(total, 3600)
m, s = divmod(rem, 60)
if h > 0:
return f"{h}h {m:02d}m {s:02d}s"
return f"{m}m {s:02d}s"


def _log_job_throughput(job, stage: str) -> None:
    """
    Write a one-line throughput/ETA summary to the per-job log so operators can
    tell stalled, slow, and healthy-but-throttled jobs apart at a glance.

    Deliberately a plain average over the whole run (processed / total elapsed),
    not a rolling-window estimate or forecast: cheap to compute, deterministic,
    and accurate enough to spot a stall from a single log line.
    """
    # Throughput only makes sense for the stages that move images through the
    # pipeline; everything else is skipped silently.
    if stage not in ("process", "results"):
        return
    if not job.started_at:
        return

    elapsed_seconds = (datetime.datetime.now() - job.started_at).total_seconds()
    elapsed_minutes = elapsed_seconds / 60.0
    if elapsed_minutes < 0.05:
        # Under ~3 seconds of runtime any computed rate is noise, not signal.
        return

    # Only the "process" stage carries the authoritative processed/remaining
    # counts (the results stage tracks detection/classification/capture counts).
    try:
        process_stage = job.progress.get_stage("process")
    except (ValueError, AttributeError):
        return

    counts = {"processed": 0, "remaining": 0}
    for param in getattr(process_stage, "params", []) or []:
        if param.key in counts:
            counts[param.key] = param.value or 0
    processed = counts["processed"]
    total = processed + counts["remaining"]

    if processed == 0:
        rate_str = "rate=0.0 imgs/min, ETA=unknown"
    else:
        rate = processed / elapsed_minutes
        outstanding = max(0, total - processed)
        eta_seconds = (outstanding / rate) * 60.0 if rate > 0 else 0.0
        rate_str = f"rate={rate:.1f} imgs/min, ETA={_format_elapsed(eta_seconds)}"

    job.logger.info(
        f"Job {job.pk} throughput: elapsed={_format_elapsed(elapsed_seconds)}, "
        f"processed={processed}/{total}, {rate_str}"
    )


def _update_job_progress(
job_id: int, stage: str, progress_percentage: float, complete_state: "JobState", **state_params
) -> None:
Expand Down Expand Up @@ -412,6 +471,7 @@ def _update_job_progress(
job.finished_at = datetime.datetime.now() # Use naive datetime in local time
job.logger.info(f"Updated job {job_id} progress in stage '{stage}' to {progress_percentage*100}%")
job.save()
_log_job_throughput(job, stage)
Comment thread
mihow marked this conversation as resolved.
Outdated

# Clean up async resources for completed jobs that use NATS/Redis
if job.progress.is_complete():
Expand Down
179 changes: 179 additions & 0 deletions ami/jobs/tests/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,185 @@ def test_result_endpoint_validation(self):
resp = self.client.post(result_url, bare_list, format="json")
self.assertEqual(resp.status_code, 400)

def test_tasks_endpoint_logs_fetch_to_job_logger(self):
    """A successful task fetch writes a 'Tasks fetched' entry to the per-job logger."""
    pipeline = self._create_pipeline()
    job = self._create_ml_job("Job for fetch-logging test", pipeline)
    job.dispatch_mode = JobDispatchMode.ASYNC_API
    job.status = JobState.STARTED
    job.save(update_fields=["dispatch_mode", "status"])

    # Queue a few images so the endpoint has real tasks to hand out.
    source_images = []
    for i in range(3):
        source_images.append(
            SourceImage.objects.create(
                path=f"fetchlog_{i}.jpg",
                public_base_url="http://example.com",
                project=self.project,
            )
        )
    queue_images_to_nats(job, source_images)

    self.client.force_authenticate(user=self.user)
    tasks_url = reverse_with_params("api:job-tasks", args=[job.pk], params={"project_id": self.project.pk})
    response = self.client.post(tasks_url, {"batch_size": 2}, format="json")
    self.assertEqual(response.status_code, 200)

    job.refresh_from_db()
    log_text = "\n".join(job.logs.stdout)
    self.assertIn("Tasks fetched", log_text)
    self.assertIn("requested=2", log_text)
    self.assertIn("delivered=", log_text)
    self.assertIn(self.user.email, log_text)

def test_tasks_endpoint_logs_early_exit_for_terminal_job(self):
    """Polling a finished job yields an empty task list plus a 'non-active job' log entry."""
    pipeline = self._create_pipeline()
    job = self._create_ml_job("Job for early-exit log test", pipeline)
    job.dispatch_mode = JobDispatchMode.ASYNC_API
    job.status = JobState.SUCCESS
    job.save(update_fields=["dispatch_mode", "status"])

    self.client.force_authenticate(user=self.user)
    tasks_url = reverse_with_params("api:job-tasks", args=[job.pk], params={"project_id": self.project.pk})
    response = self.client.post(tasks_url, {"batch_size": 5}, format="json")

    # The endpoint answers 200 with an empty batch rather than erroring out.
    self.assertEqual(response.status_code, 200)
    self.assertEqual(response.json(), {"tasks": []})

    job.refresh_from_db()
    log_text = "\n".join(job.logs.stdout)
    self.assertIn("non-active job", log_text)
    self.assertIn(f"status={JobState.SUCCESS}", log_text)

def test_result_endpoint_mirrors_queued_log_to_job_logger(self):
    """POSTing a result mirrors the 'Queued pipeline result' line into the per-job logger."""
    from unittest.mock import MagicMock, patch

    pipeline = self._create_pipeline()
    job = self._create_ml_job("Job for result-logging test", pipeline)

    self.client.force_authenticate(user=self.user)
    result_url = reverse_with_params("api:job-result", args=[job.pk], params={"project_id": self.project.pk})

    payload = {
        "results": [
            {
                "reply_subject": "test.reply.logged",
                "result": {
                    "pipeline": "test-pipeline",
                    "algorithms": {},
                    "total_time": 0.1,
                    "source_images": [],
                    "detections": [],
                    "errors": None,
                },
            }
        ]
    }

    # Stub out the Celery dispatch: the log line under test is emitted by the
    # view itself, before control would pass to the task.
    fake_async_result = MagicMock()
    fake_async_result.id = "mirrored-task-id"
    with patch("ami.jobs.views.process_nats_pipeline_result.delay", return_value=fake_async_result):
        response = self.client.post(result_url, payload, format="json")
        self.assertEqual(response.status_code, 200)

    job.refresh_from_db()
    log_text = "\n".join(job.logs.stdout)
    self.assertIn("Queued pipeline result", log_text)
    self.assertIn("mirrored-task-id", log_text)
    self.assertIn("test.reply.logged", log_text)
    self.assertIn(self.user.email, log_text)


class TestJobThroughputLogging(TestCase):
    """Unit-level coverage for _log_job_throughput (Task 3)."""

    def setUp(self):
        self.project = Project.objects.create(name="Throughput Test Project")
        self.pipeline = Pipeline.objects.create(name="Throughput Pipeline", slug="throughput-pipeline")
        self.pipeline.projects.add(self.project)
        self.job = Job.objects.create(
            job_type_key=MLJob.key,
            project=self.project,
            name="Throughput job",
            pipeline=self.pipeline,
        )

    def _seed_process_stage(self, processed: int, remaining: int) -> None:
        # Give the job a "process" stage carrying the counts the helper reads.
        total = max(1, processed + remaining)
        self.job.progress.add_stage("process")
        self.job.progress.update_stage(
            "process",
            progress=processed / total,
            status=JobState.STARTED,
            processed=processed,
            remaining=remaining,
            failed=0,
        )
        self.job.save()

    def test_throughput_line_is_well_formed(self):
        import datetime

        from ami.jobs.tasks import _log_job_throughput

        self._seed_process_stage(processed=10, remaining=90)
        self.job.started_at = datetime.datetime.now() - datetime.timedelta(minutes=5)
        self.job.save(update_fields=["started_at"])

        _log_job_throughput(self.job, "process")

        self.job.refresh_from_db()
        log_output = "\n".join(self.job.logs.stdout)
        self.assertIn("throughput", log_output)
        self.assertIn("processed=10/100", log_output)
        # 10 images over 5 minutes -> 2.0 imgs/min.
        self.assertIn("rate=2.0 imgs/min", log_output)
        # 90 outstanding images at 2.0 imgs/min -> 45 minutes to go.
        self.assertIn("ETA=45m", log_output)

    def test_throughput_skipped_when_started_at_is_none(self):
        from ami.jobs.tasks import _log_job_throughput

        self._seed_process_stage(processed=5, remaining=5)
        self.assertIsNone(self.job.started_at)

        _log_job_throughput(self.job, "process")

        # Without a start time there is no elapsed window, so nothing is logged.
        self.job.refresh_from_db()
        log_output = "\n".join(self.job.logs.stdout)
        self.assertNotIn("throughput", log_output)

    def test_throughput_skipped_for_non_processing_stage(self):
        import datetime

        from ami.jobs.tasks import _log_job_throughput

        self._seed_process_stage(processed=10, remaining=90)
        self.job.started_at = datetime.datetime.now() - datetime.timedelta(minutes=5)
        self.job.save(update_fields=["started_at"])

        # "delay" is not one of the throughput-relevant stages.
        _log_job_throughput(self.job, "delay")

        self.job.refresh_from_db()
        log_output = "\n".join(self.job.logs.stdout)
        self.assertNotIn("throughput", log_output)

    def test_throughput_with_zero_processed_reports_unknown_eta(self):
        import datetime

        from ami.jobs.tasks import _log_job_throughput

        self._seed_process_stage(processed=0, remaining=50)
        self.job.started_at = datetime.datetime.now() - datetime.timedelta(minutes=5)
        self.job.save(update_fields=["started_at"])

        _log_job_throughput(self.job, "process")

        self.job.refresh_from_db()
        log_output = "\n".join(self.job.logs.stdout)
        self.assertIn("processed=0/50", log_output)
        self.assertIn("rate=0.0", log_output)
        self.assertIn("ETA=unknown", log_output)


class TestJobDispatchModeFiltering(APITestCase):
"""Test job filtering by dispatch_mode."""
Expand Down
28 changes: 26 additions & 2 deletions ami/jobs/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,8 +267,17 @@ def tasks(self, request, pk=None):
if job.dispatch_mode != JobDispatchMode.ASYNC_API:
raise ValidationError("Only async_api jobs have fetchable tasks")

# Only serve tasks for actively processing jobs
user_desc = getattr(request.user, "email", None) or str(request.user)
token_id = getattr(request.auth, "pk", None)

Comment thread
mihow marked this conversation as resolved.
Outdated
# Only serve tasks for actively processing jobs. Logging the early-exit
# makes "phantom-pull" workers (polling against terminal jobs whose NATS
# stream still exists) visible from the per-job log view.
if job.status not in JobState.active_states():
job.logger.info(
f"Tasks requested for non-active job (status={job.status}); returning empty. "
f"user={user_desc}, token_id={token_id}"
)
return Response({"tasks": []})

# Validate that the job has a pipeline
Expand All @@ -288,9 +297,14 @@ async def get_tasks():
try:
tasks = async_to_sync(get_tasks)()
except (asyncio.TimeoutError, OSError, nats.errors.Error) as e:
logger.warning("NATS unavailable while fetching tasks for job %s: %s", job.pk, e)
msg = f"NATS unavailable while fetching tasks for job {job.pk}: {e}"
logger.warning(msg)
job.logger.warning(f"{msg} user={user_desc}, token_id={token_id}")
return Response({"error": "Task queue temporarily unavailable"}, status=503)

job.logger.info(
f"Tasks fetched: requested={batch_size}, delivered={len(tasks)}, " f"user={user_desc}, token_id={token_id}"
)
return Response({"tasks": tasks})

@extend_schema(
Expand All @@ -313,6 +327,9 @@ def result(self, request, pk=None):
# Record heartbeat for async processing services on this pipeline
_mark_pipeline_pull_services_seen(job)

user_desc = getattr(request.user, "email", None) or str(request.user)
token_id = getattr(request.auth, "pk", None)

serializer = MLJobResultsRequestSerializer(data=request.data)
serializer.is_valid(raise_exception=True)
validated_results = serializer.validated_data["results"]
Expand Down Expand Up @@ -344,6 +361,13 @@ def result(self, request, pk=None):
task.id,
reply_subject,
)
# Mirror to per-job logger so the job log view shows result-POST
# activity alongside task-fetch activity. Module-logger line above
# stays for ops-level monitoring outside the per-job context.
job.logger.info(
f"Queued pipeline result: task_id={task.id}, reply_subject={reply_subject}, "
f"user={user_desc}, token_id={token_id}"
Comment thread
mihow marked this conversation as resolved.
Outdated
)

return Response(
{
Expand Down
Loading