10 changes: 4 additions & 6 deletions ami/jobs/schemas.py
@@ -1,4 +1,5 @@
 from drf_spectacular.utils import OpenApiParameter
+from rest_framework import serializers
 
 ids_only_param = OpenApiParameter(
     name="ids_only",
@@ -14,9 +15,6 @@
     type=bool,
 )
 
-batch_param = OpenApiParameter(
-    name="batch",
-    description="Number of tasks to retrieve",
-    required=False,
-    type=int,
-)
+
+class TasksRequestSerializer(serializers.Serializer):
+    batch = serializers.IntegerField(min_value=1, required=True)
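
Note: the new TasksRequestSerializer replaces the documented batch query parameter with request-body validation. A minimal sketch of how it behaves, assuming standard DRF Serializer semantics (not part of the diff):

from rest_framework import serializers

class TasksRequestSerializer(serializers.Serializer):
    batch = serializers.IntegerField(min_value=1, required=True)

TasksRequestSerializer(data={"batch": 5}).is_valid()  # True
TasksRequestSerializer(data={"batch": 0}).is_valid()  # False: violates min_value=1
TasksRequestSerializer(data={}).is_valid()            # False: batch is required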
18 changes: 6 additions & 12 deletions ami/jobs/tests/test_jobs.py
@@ -489,10 +489,8 @@ def _task_batch_helper(self, value: Any, expected_status: int):
         queue_images_to_nats(job, images)
 
         self.client.force_authenticate(user=self.user)
-        tasks_url = reverse_with_params(
-            "api:job-tasks", args=[job.pk], params={"project_id": self.project.pk, "batch": value}
-        )
-        resp = self.client.get(tasks_url)
+        tasks_url = reverse_with_params("api:job-tasks", args=[job.pk], params={"project_id": self.project.pk})
+        resp = self.client.post(tasks_url, {"batch": value}, format="json")
         self.assertEqual(resp.status_code, expected_status)
         return resp.json()

@@ -523,10 +521,8 @@ def test_tasks_endpoint_without_pipeline(self):
         )
 
         self.client.force_authenticate(user=self.user)
-        tasks_url = reverse_with_params(
-            "api:job-tasks", args=[job.pk], params={"project_id": self.project.pk, "batch": 1}
-        )
-        resp = self.client.get(tasks_url)
+        tasks_url = reverse_with_params("api:job-tasks", args=[job.pk], params={"project_id": self.project.pk})
+        resp = self.client.post(tasks_url, {"batch": 1}, format="json")
 
         self.assertEqual(resp.status_code, 400)
         self.assertIn("pipeline", resp.json()[0].lower())
@@ -722,9 +718,7 @@ def test_tasks_endpoint_rejects_non_async_jobs(self):
         )
 
         self.client.force_authenticate(user=self.user)
-        tasks_url = reverse_with_params(
-            "api:job-tasks", args=[sync_job.pk], params={"project_id": self.project.pk, "batch": 1}
-        )
-        resp = self.client.get(tasks_url)
+        tasks_url = reverse_with_params("api:job-tasks", args=[sync_job.pk], params={"project_id": self.project.pk})
+        resp = self.client.post(tasks_url, {"batch": 1}, format="json")
         self.assertEqual(resp.status_code, 400)
         self.assertIn("async_api", resp.json()[0].lower())
55 changes: 30 additions & 25 deletions ami/jobs/views.py
@@ -17,7 +17,7 @@
 
 from ami.base.permissions import ObjectPermission
 from ami.base.views import ProjectMixin
-from ami.jobs.schemas import batch_param, ids_only_param, incomplete_only_param
+from ami.jobs.schemas import TasksRequestSerializer, ids_only_param, incomplete_only_param
 from ami.jobs.tasks import process_nats_pipeline_result
 from ami.main.api.schemas import project_id_doc_param
 from ami.main.api.views import DefaultViewSet
@@ -238,24 +238,25 @@ def list(self, request, *args, **kwargs):
         return super().list(request, *args, **kwargs)
 
     @extend_schema(
-        parameters=[batch_param],
+        request=TasksRequestSerializer,
+        responses={200: dict},
         parameters=[project_id_doc_param],
     )
-    @action(detail=True, methods=["get"], name="tasks")
+    @action(detail=True, methods=["post"], name="tasks")
     def tasks(self, request, pk=None):
         """
-        Get tasks from the job queue.
+        Fetch tasks from the job queue (POST).
 
         Returns task data with reply_subject for acknowledgment. External workers should:
-        1. Call this endpoint to get tasks
+        1. POST to this endpoint with {"batch": N}
         2. Process the tasks
-        3. POST to /jobs/{id}/result/ with the reply_subject to acknowledge
+        3. POST to /jobs/{id}/result/ with the results
         """
+        serializer = TasksRequestSerializer(data=request.data)
+        serializer.is_valid(raise_exception=True)
+        batch = serializer.validated_data["batch"]
+
         job: Job = self.get_object()
-        try:
-            batch = IntegerField(required=True, min_value=1).clean(request.query_params.get("batch"))
-        except Exception as e:
-            raise ValidationError({"batch": str(e)}) from e
 
         # Only async_api jobs have tasks fetchable from NATS
         if job.dispatch_mode != JobDispatchMode.ASYNC_API:
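
The docstring above describes the worker protocol end to end. A hedged worker-side sketch in Python (the base URL, auth scheme, and exact response shape are assumptions; only the two endpoints, the batch body field, and the per-task reply_subject are confirmed by this diff):

import requests

BASE_URL = "https://antenna.example.org/api/v2"  # assumed base URL
HEADERS = {"Authorization": "Token <worker-token>"}  # assumed auth scheme
JOB_ID = 123  # hypothetical job id

# 1. POST to the tasks endpoint with {"batch": N}
resp = requests.post(
    f"{BASE_URL}/jobs/{JOB_ID}/tasks/",
    params={"project_id": 1},
    json={"batch": 10},
    headers=HEADERS,
)
resp.raise_for_status()
tasks = resp.json().get("tasks", [])  # response key is an assumption

# 2. Process each task, keeping its reply_subject for acknowledgment
results = [
    {"reply_subject": task["reply_subject"], "data": {"detections": []}}  # result fields assumed
    for task in tasks
]

# 3. POST the results back; the server validates and queues them via Celery
ack = requests.post(f"{BASE_URL}/jobs/{JOB_ID}/result/", json={"results": results}, headers=HEADERS)
ack.raise_for_status()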
@@ -290,31 +291,33 @@ async def get_tasks():
     @action(detail=True, methods=["post"], name="result")
     def result(self, request, pk=None):
         """
-        The request body should be a list of results: list[PipelineTaskResult]
-
-        This endpoint accepts a list of pipeline results and queues them for
-        background processing. Each result will be validated, saved to the database,
-        and acknowledged via NATS in a Celery task.
+        Submit pipeline results.
+
+        Accepts: {"results": [PipelineTaskResult, ...]}
+        Or legacy: [PipelineTaskResult, ...] (bare list)
+
+        Results are validated then queued for background processing via Celery.
         """
 
         job = self.get_object()
 
         # Record heartbeat for async processing services on this pipeline
        _mark_pipeline_pull_services_seen(job)
 
-        # Validate request data is a list
+        # Accept both wrapped format and legacy bare list
         if isinstance(request.data, list):
-            results = request.data
+            raw_results = request.data
+        elif isinstance(request.data, dict) and "results" in request.data:
+            raw_results = request.data["results"]
+            if not isinstance(raw_results, list):
+                raise ValidationError("'results' must be a list")
         else:
-            results = [request.data]
+            raw_results = [request.data]
 
         try:
             # Pre-validate all results before enqueuing any tasks
             # This prevents partial queueing and duplicate task processing
-            validated_results = []
-            for item in results:
-                task_result = PipelineTaskResult(**item)
-                validated_results.append(task_result)
+            validated_results = pydantic.parse_obj_as(list[PipelineTaskResult], raw_results)
 
             # All validation passed, now queue all tasks
             queued_tasks = []
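
The parse_obj_as call makes the batch validation fail fast: any invalid item rejects the whole request before a single Celery task is queued. A small illustration using a simplified stand-in for PipelineTaskResult (its real fields are defined elsewhere in the codebase), assuming pydantic v1's parse_obj_as:

import pydantic

class PipelineTaskResult(pydantic.BaseModel):  # simplified stand-in; fields assumed
    reply_subject: str

good = [{"reply_subject": "ack.1"}, {"reply_subject": "ack.2"}]
bad = [{"reply_subject": "ack.1"}, {}]  # second item is missing a required field

pydantic.parse_obj_as(list[PipelineTaskResult], good)  # returns two validated models
try:
    pydantic.parse_obj_as(list[PipelineTaskResult], bad)
except pydantic.ValidationError:
    pass  # the whole batch is rejected, so nothing is partially queued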
@@ -337,23 +340,25 @@ def result(self, request, pk=None):
                 )
 
                 logger.info(
-                    f"Queued pipeline result processing for job {job.pk}, "
-                    f"task_id: {task.id}, reply_subject: {reply_subject}"
+                    "Queued pipeline result for job %s, task_id: %s, reply_subject: %s",
+                    job.pk,
+                    task.id,
+                    reply_subject,
                 )
 
             return Response(
                 {
                     "status": "accepted",
                     "job_id": job.pk,
-                    "results_queued": len([t for t in queued_tasks if t["status"] == "queued"]),
+                    "results_queued": len(queued_tasks),
                     "tasks": queued_tasks,
                 }
             )
+        except pydantic.ValidationError as e:
+            raise ValidationError(f"Invalid result data: {e}") from e
 
         except Exception as e:
-            logger.error(f"Failed to queue pipeline results for job {job.pk}: {e}")
+            logger.error("Failed to queue pipeline results for job %s: %s", job.pk, e)
             return Response(
                 {
                     "status": "error",