38 changes: 32 additions & 6 deletions ami/jobs/schemas.py
@@ -1,4 +1,8 @@
from django_pydantic_field.rest_framework import SchemaField
from drf_spectacular.utils import OpenApiParameter
from rest_framework import serializers

from ami.ml.schemas import PipelineTaskResult, ProcessingServiceClientInfo

ids_only_param = OpenApiParameter(
name="ids_only",
@@ -14,9 +18,31 @@
type=bool,
)

batch_param = OpenApiParameter(
name="batch",
description="Number of tasks to retrieve",
required=False,
type=int,
)

class TasksRequestSerializer(serializers.Serializer):
"""POST /jobs/{id}/tasks/ request body. Fetch tasks from the job queue."""

batch = serializers.IntegerField(min_value=1, required=True)
client_info = SchemaField(schema=ProcessingServiceClientInfo, required=False, default=None)


class TasksResponseSerializer(serializers.Serializer):
"""POST /jobs/{id}/tasks/ response body. Tasks returned to the processing service."""

tasks = serializers.ListField(child=serializers.DictField(), default=[])


class PipelineResultsRequestSerializer(serializers.Serializer):
"""POST /jobs/{id}/result/ request body. Submit pipeline results for processing."""

results = SchemaField(schema=list[PipelineTaskResult])
client_info = SchemaField(schema=ProcessingServiceClientInfo, required=False, default=None)


class PipelineResultsResponseSerializer(serializers.Serializer):
"""POST /jobs/{id}/result/ response body. Acknowledgment of queued results."""

status = serializers.CharField()
job_id = serializers.IntegerField()
results_queued = serializers.IntegerField()
tasks = serializers.ListField(child=serializers.DictField(), default=[])
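A minimal sketch of the new request contract, runnable from the project's Django shell; the payload values are illustrative, not taken from this PR:

```python
# Illustrative payloads only; field names come from the serializers above.
from ami.jobs.schemas import TasksRequestSerializer

payload = {
    "batch": 5,
    "client_info": {"hostname": "worker-1", "software": "ami-worker"},
}
serializer = TasksRequestSerializer(data=payload)
assert serializer.is_valid(), serializer.errors
print(serializer.validated_data["batch"])  # 5

# batch is required and must be >= 1:
assert not TasksRequestSerializer(data={"batch": 0}).is_valid()
assert not TasksRequestSerializer(data={}).is_valid()
```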
57 changes: 28 additions & 29 deletions ami/jobs/tests/test_jobs.py
@@ -489,10 +489,8 @@ def _task_batch_helper(self, value: Any, expected_status: int):
queue_images_to_nats(job, images)

self.client.force_authenticate(user=self.user)
tasks_url = reverse_with_params(
"api:job-tasks", args=[job.pk], params={"project_id": self.project.pk, "batch": value}
)
resp = self.client.get(tasks_url)
tasks_url = reverse_with_params("api:job-tasks", args=[job.pk], params={"project_id": self.project.pk})
resp = self.client.post(tasks_url, {"batch": value}, format="json")
self.assertEqual(resp.status_code, expected_status)
return resp.json()

@@ -523,10 +521,8 @@ def test_tasks_endpoint_without_pipeline(self):
)

self.client.force_authenticate(user=self.user)
tasks_url = reverse_with_params(
"api:job-tasks", args=[job.pk], params={"project_id": self.project.pk, "batch": 1}
)
resp = self.client.get(tasks_url)
tasks_url = reverse_with_params("api:job-tasks", args=[job.pk], params={"project_id": self.project.pk})
resp = self.client.post(tasks_url, {"batch": 1}, format="json")

self.assertEqual(resp.status_code, 400)
self.assertIn("pipeline", resp.json()[0].lower())
@@ -541,19 +537,21 @@ def test_result_endpoint_stub(self):
"api:job-result", args=[job.pk], params={"project_id": self.project.pk, "batch": 1}
)

result_data = [
{
"reply_subject": "test.reply.1",
"result": {
"pipeline": "test-pipeline",
"algorithms": {},
"total_time": 1.5,
"source_images": [],
"detections": [],
"errors": None,
},
}
]
result_data = {
"results": [
{
"reply_subject": "test.reply.1",
"result": {
"pipeline": "test-pipeline",
"algorithms": {},
"total_time": 1.5,
"source_images": [],
"detections": [],
"errors": None,
},
}
]
}

resp = self.client.post(result_url, result_data, format="json")

@@ -572,16 +570,19 @@ def test_result_endpoint_validation(self):
result_url = reverse_with_params("api:job-result", args=[job.pk], params={"project_id": self.project.pk})

# Test with missing reply_subject
invalid_data = [{"result": {"pipeline": "test"}}]
invalid_data = {"results": [{"result": {"pipeline": "test"}}]}
resp = self.client.post(result_url, invalid_data, format="json")
self.assertEqual(resp.status_code, 400)
self.assertIn("reply_subject", resp.json()[0].lower())

# Test with missing result
invalid_data = [{"reply_subject": "test.reply"}]
invalid_data = {"results": [{"reply_subject": "test.reply"}]}
resp = self.client.post(result_url, invalid_data, format="json")
self.assertEqual(resp.status_code, 400)
self.assertIn("result", resp.json()[0].lower())

# Test with bare list (no longer accepted)
bare_list = [{"reply_subject": "test.reply", "result": {"pipeline": "test"}}]
resp = self.client.post(result_url, bare_list, format="json")
self.assertEqual(resp.status_code, 400)


class TestJobDispatchModeFiltering(APITestCase):
@@ -722,9 +723,7 @@ def test_tasks_endpoint_rejects_non_async_jobs(self):
)

self.client.force_authenticate(user=self.user)
tasks_url = reverse_with_params(
"api:job-tasks", args=[sync_job.pk], params={"project_id": self.project.pk, "batch": 1}
)
resp = self.client.get(tasks_url)
tasks_url = reverse_with_params("api:job-tasks", args=[sync_job.pk], params={"project_id": self.project.pk})
resp = self.client.post(tasks_url, {"batch": 1}, format="json")
self.assertEqual(resp.status_code, 400)
self.assertIn("async_api", resp.json()[0].lower())
72 changes: 37 additions & 35 deletions ami/jobs/views.py
@@ -2,7 +2,6 @@
import logging

import nats.errors
import pydantic
from asgiref.sync import async_to_sync
from django.db.models import Q
from django.db.models.query import QuerySet
@@ -17,11 +16,17 @@

from ami.base.permissions import ObjectPermission
from ami.base.views import ProjectMixin
from ami.jobs.schemas import batch_param, ids_only_param, incomplete_only_param
from ami.jobs.schemas import (
PipelineResultsRequestSerializer,
PipelineResultsResponseSerializer,
TasksRequestSerializer,
TasksResponseSerializer,
ids_only_param,
incomplete_only_param,
)
from ami.jobs.tasks import process_nats_pipeline_result
from ami.main.api.schemas import project_id_doc_param
from ami.main.api.views import DefaultViewSet
from ami.ml.schemas import PipelineTaskResult
from ami.utils.fields import url_boolean_param

from .models import Job, JobDispatchMode, JobState
@@ -238,24 +243,25 @@ def list(self, request, *args, **kwargs):
return super().list(request, *args, **kwargs)

@extend_schema(
parameters=[batch_param],
responses={200: dict},
request=TasksRequestSerializer,
responses={200: TasksResponseSerializer},
parameters=[project_id_doc_param],
)
@action(detail=True, methods=["get"], name="tasks")
@action(detail=True, methods=["post"], name="tasks")
def tasks(self, request, pk=None):
"""
Get tasks from the job queue.
Fetch tasks from the job queue (POST).

Returns task data with reply_subject for acknowledgment. External workers should:
1. Call this endpoint to get tasks
1. POST to this endpoint with {"batch": N}
2. Process the tasks
3. POST to /jobs/{id}/result/ with the reply_subject to acknowledge
3. POST to /jobs/{id}/result/ with the results
"""
serializer = TasksRequestSerializer(data=request.data)
serializer.is_valid(raise_exception=True)
batch = serializer.validated_data["batch"]

job: Job = self.get_object()
try:
batch = IntegerField(required=True, min_value=1).clean(request.query_params.get("batch"))
except Exception as e:
raise ValidationError({"batch": str(e)}) from e

# Only async_api jobs have tasks fetchable from NATS
if job.dispatch_mode != JobDispatchMode.ASYNC_API:
@@ -287,35 +293,31 @@ async def get_tasks():

return Response({"tasks": tasks})

@extend_schema(
request=PipelineResultsRequestSerializer,
responses={200: PipelineResultsResponseSerializer},
parameters=[project_id_doc_param],
)
@action(detail=True, methods=["post"], name="result")
def result(self, request, pk=None):
"""
The request body should be a list of results: list[PipelineTaskResult]
Submit pipeline results.

Accepts: {"results": [PipelineTaskResult, ...]}

This endpoint accepts a list of pipeline results and queues them for
background processing. Each result will be validated, saved to the database,
and acknowledged via NATS in a Celery task.
Results are validated then queued for background processing via Celery.
"""

job = self.get_object()

# Record heartbeat for async processing services on this pipeline
_mark_pipeline_pull_services_seen(job)

# Validate request data is a list
if isinstance(request.data, list):
results = request.data
else:
results = [request.data]
serializer = PipelineResultsRequestSerializer(data=request.data)
serializer.is_valid(raise_exception=True)
validated_results = serializer.validated_data["results"]

try:
# Pre-validate all results before enqueuing any tasks
# This prevents partial queueing and duplicate task processing
validated_results = []
for item in results:
task_result = PipelineTaskResult(**item)
validated_results.append(task_result)

# All validation passed, now queue all tasks
queued_tasks = []
for task_result in validated_results:
@@ -337,23 +339,23 @@ def result(self, request, pk=None):
)

logger.info(
f"Queued pipeline result processing for job {job.pk}, "
f"task_id: {task.id}, reply_subject: {reply_subject}"
"Queued pipeline result for job %s, task_id: %s, reply_subject: %s",
job.pk,
task.id,
reply_subject,
)

return Response(
{
"status": "accepted",
"job_id": job.pk,
"results_queued": len([t for t in queued_tasks if t["status"] == "queued"]),
"results_queued": len(queued_tasks),
"tasks": queued_tasks,
}
)
except pydantic.ValidationError as e:
raise ValidationError(f"Invalid result data: {e}") from e

except Exception as e:
logger.error(f"Failed to queue pipeline results for job {job.pk}: {e}")
logger.error("Failed to queue pipeline results for job %s: %s", job.pk, e)
return Response(
{
"status": "error",
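And the matching submission step; the payload shape mirrors the new serializer and the stub test above, while the URL and auth are again assumptions:

```python
# Hypothetical result submission. The view validates every item before
# queueing any, so a single malformed result rejects the whole batch.
import requests

payload = {
    "results": [
        {
            "reply_subject": "test.reply.1",
            "result": {
                "pipeline": "test-pipeline",
                "algorithms": {},
                "total_time": 1.5,
                "source_images": [],
                "detections": [],
                "errors": None,
            },
        }
    ]
}
resp = requests.post(
    "https://antenna.example.org/api/v2/jobs/123/result/",
    params={"project_id": 1},
    json=payload,
    headers={"Authorization": "Token <redacted>"},
)
resp.raise_for_status()
print(resp.json()["results_queued"])  # 1 on success
```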
14 changes: 14 additions & 0 deletions ami/ml/schemas.py
@@ -262,6 +262,20 @@ class PipelineProcessingTask(pydantic.BaseModel):
# config: PipelineRequestConfigParameters | dict | None = None


class ProcessingServiceClientInfo(pydantic.BaseModel):
"""Identity metadata for a specific processing service instance.

A single ProcessingService may have multiple workers/pods.
This identifies which one is making the request.
"""

hostname: str = ""
software: str = ""
version: str = ""
platform: str = ""
pod_name: str = ""


class PipelineTaskResult(pydantic.BaseModel):
"""
The result from processing a single PipelineProcessingTask.
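A small sketch of how a worker might populate this model; the service name, version string, and environment variable are assumptions:

```python
# Illustrative only: building client_info from the worker's environment.
import os
import platform as platform_module

from ami.ml.schemas import ProcessingServiceClientInfo

client_info = ProcessingServiceClientInfo(
    hostname=platform_module.node(),
    software="ami-worker",  # assumed service name
    version="0.1.0",  # assumed version string
    platform=platform_module.platform(),
    pod_name=os.environ.get("POD_NAME", ""),  # e.g. set via the Kubernetes downward API
)
# .model_dump() on pydantic v2; use .dict() if the project is on v1.
print(client_info.model_dump())
```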