diff --git a/ami/jobs/schemas.py b/ami/jobs/schemas.py index 0e1ea4ac7..74af39ce9 100644 --- a/ami/jobs/schemas.py +++ b/ami/jobs/schemas.py @@ -1,5 +1,15 @@ +import pydantic from drf_spectacular.utils import OpenApiParameter + +class QueuedTaskAcknowledgment(pydantic.BaseModel): + """Acknowledgment for a single result that was queued for background processing.""" + + reply_subject: str + status: str + task_id: str + + ids_only_param = OpenApiParameter( name="ids_only", description="Return only job IDs instead of full objects", @@ -13,10 +23,3 @@ required=False, type=bool, ) - -batch_param = OpenApiParameter( - name="batch", - description="Number of tasks to retrieve", - required=False, - type=int, -) diff --git a/ami/jobs/serializers.py b/ami/jobs/serializers.py index d903b0812..fc2fcf8be 100644 --- a/ami/jobs/serializers.py +++ b/ami/jobs/serializers.py @@ -10,9 +10,11 @@ ) from ami.main.models import Deployment, Project, SourceImage, SourceImageCollection from ami.ml.models import Pipeline +from ami.ml.schemas import PipelineProcessingTask, PipelineTaskResult, ProcessingServiceClientInfo from ami.ml.serializers import PipelineNestedSerializer from .models import Job, JobLogs, JobProgress, MLJob +from .schemas import QueuedTaskAcknowledgment class JobProjectNestedSerializer(DefaultSerializer): @@ -163,3 +165,53 @@ class MinimalJobSerializer(DefaultSerializer): class Meta: model = Job fields = ["id", "pipeline_slug"] + + +class MLJobTasksRequestSerializer(serializers.Serializer): + """POST /jobs/{id}/tasks/ — request body sent by a processing service to fetch work. + + The processing service polls this endpoint to get tasks (images) to process. + Each task is a PipelineProcessingTask with an image URL and a NATS reply subject. + """ + + batch_size = serializers.IntegerField(min_value=1, required=True) + client_info = SchemaField(schema=ProcessingServiceClientInfo, required=False, default=None) + + +class MLJobTasksResponseSerializer(serializers.Serializer): + """POST /jobs/{id}/tasks/ — response body returned to the processing service. + + Contains a list of tasks (PipelineProcessingTask dicts) for the worker to process. + Each task includes an image URL, task ID, and reply_subject for result correlation. + Returns an empty list when no tasks are available or the job is not active. + """ + + tasks = SchemaField(schema=list[PipelineProcessingTask], default=[]) + + +class MLJobResultsRequestSerializer(serializers.Serializer): + """POST /jobs/{id}/result/ — request body sent by a processing service to deliver results. + + "Request" here refers to the HTTP request to Antenna, not a request for work. + The processing service has finished processing tasks and is posting its results + (successes or errors) back. Each PipelineTaskResult contains a reply_subject + (correlating back to the original task) and a result payload that is either a + PipelineResultsResponse (success) or PipelineResultsError (failure). + """ + + results = SchemaField(schema=list[PipelineTaskResult]) + client_info = SchemaField(schema=ProcessingServiceClientInfo, required=False, default=None) + + +class MLJobResultsResponseSerializer(serializers.Serializer): + """POST /jobs/{id}/result/ — acknowledgment returned to the processing service. + + Confirms receipt and indicates how many results were queued for background + processing via Celery. Individual task entries include their Celery task_id + for traceability. 
+ """ + + status = serializers.CharField() + job_id = serializers.IntegerField() + results_queued = serializers.IntegerField() + tasks = SchemaField(schema=list[QueuedTaskAcknowledgment], default=[]) diff --git a/ami/jobs/tests/test_jobs.py b/ami/jobs/tests/test_jobs.py index 7f2607bfe..7241b0a57 100644 --- a/ami/jobs/tests/test_jobs.py +++ b/ami/jobs/tests/test_jobs.py @@ -489,10 +489,8 @@ def _task_batch_helper(self, value: Any, expected_status: int): queue_images_to_nats(job, images) self.client.force_authenticate(user=self.user) - tasks_url = reverse_with_params( - "api:job-tasks", args=[job.pk], params={"project_id": self.project.pk, "batch": value} - ) - resp = self.client.get(tasks_url) + tasks_url = reverse_with_params("api:job-tasks", args=[job.pk], params={"project_id": self.project.pk}) + resp = self.client.post(tasks_url, {"batch_size": value}, format="json") self.assertEqual(resp.status_code, expected_status) return resp.json() @@ -523,10 +521,8 @@ def test_tasks_endpoint_without_pipeline(self): ) self.client.force_authenticate(user=self.user) - tasks_url = reverse_with_params( - "api:job-tasks", args=[job.pk], params={"project_id": self.project.pk, "batch": 1} - ) - resp = self.client.get(tasks_url) + tasks_url = reverse_with_params("api:job-tasks", args=[job.pk], params={"project_id": self.project.pk}) + resp = self.client.post(tasks_url, {"batch_size": 1}, format="json") self.assertEqual(resp.status_code, 400) self.assertIn("pipeline", resp.json()[0].lower()) @@ -537,23 +533,23 @@ def test_result_endpoint_stub(self): job = self._create_ml_job("Job for results test", pipeline) self.client.force_authenticate(user=self.user) - result_url = reverse_with_params( - "api:job-result", args=[job.pk], params={"project_id": self.project.pk, "batch": 1} - ) + result_url = reverse_with_params("api:job-result", args=[job.pk], params={"project_id": self.project.pk}) - result_data = [ - { - "reply_subject": "test.reply.1", - "result": { - "pipeline": "test-pipeline", - "algorithms": {}, - "total_time": 1.5, - "source_images": [], - "detections": [], - "errors": None, - }, - } - ] + result_data = { + "results": [ + { + "reply_subject": "test.reply.1", + "result": { + "pipeline": "test-pipeline", + "algorithms": {}, + "total_time": 1.5, + "source_images": [], + "detections": [], + "errors": None, + }, + } + ] + } resp = self.client.post(result_url, result_data, format="json") @@ -572,16 +568,19 @@ def test_result_endpoint_validation(self): result_url = reverse_with_params("api:job-result", args=[job.pk], params={"project_id": self.project.pk}) # Test with missing reply_subject - invalid_data = [{"result": {"pipeline": "test"}}] + invalid_data = {"results": [{"result": {"pipeline": "test"}}]} resp = self.client.post(result_url, invalid_data, format="json") self.assertEqual(resp.status_code, 400) - self.assertIn("reply_subject", resp.json()[0].lower()) # Test with missing result - invalid_data = [{"reply_subject": "test.reply"}] + invalid_data = {"results": [{"reply_subject": "test.reply"}]} resp = self.client.post(result_url, invalid_data, format="json") self.assertEqual(resp.status_code, 400) - self.assertIn("result", resp.json()[0].lower()) + + # Test with bare list (no longer accepted) + bare_list = [{"reply_subject": "test.reply", "result": {"pipeline": "test"}}] + resp = self.client.post(result_url, bare_list, format="json") + self.assertEqual(resp.status_code, 400) class TestJobDispatchModeFiltering(APITestCase): @@ -722,9 +721,7 @@ def test_tasks_endpoint_rejects_non_async_jobs(self): 
) self.client.force_authenticate(user=self.user) - tasks_url = reverse_with_params( - "api:job-tasks", args=[sync_job.pk], params={"project_id": self.project.pk, "batch": 1} - ) - resp = self.client.get(tasks_url) + tasks_url = reverse_with_params("api:job-tasks", args=[sync_job.pk], params={"project_id": self.project.pk}) + resp = self.client.post(tasks_url, {"batch_size": 1}, format="json") self.assertEqual(resp.status_code, 400) self.assertIn("async_api", resp.json()[0].lower()) diff --git a/ami/jobs/tests/test_tasks.py b/ami/jobs/tests/test_tasks.py index daf1b6ae6..d183dfb3c 100644 --- a/ami/jobs/tests/test_tasks.py +++ b/ami/jobs/tests/test_tasks.py @@ -384,16 +384,18 @@ def test_result_endpoint_with_error_result(self, mock_apply_async): self.client.force_authenticate(user=self.user) result_url = reverse_with_params("api:job-result", args=[self.job.pk], params={"project_id": self.project.pk}) - # Create error result data - result_data = [ - { - "reply_subject": "test.reply.error.1", - "result": { - "error": "Image processing timeout", - "image_id": str(self.image.pk), - }, - } - ] + # Create error result data (wrapped format) + result_data = { + "results": [ + { + "reply_subject": "test.reply.error.1", + "result": { + "error": "Image processing timeout", + "image_id": str(self.image.pk), + }, + } + ] + } # POST error result to API resp = self.client.post(result_url, result_data, format="json") diff --git a/ami/jobs/views.py b/ami/jobs/views.py index 832e15f30..625fb8b47 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -1,8 +1,8 @@ import asyncio import logging +import kombu.exceptions import nats.errors -import pydantic from asgiref.sync import async_to_sync from django.db.models import Q from django.db.models.query import QuerySet @@ -17,11 +17,16 @@ from ami.base.permissions import ObjectPermission from ami.base.views import ProjectMixin -from ami.jobs.schemas import batch_param, ids_only_param, incomplete_only_param +from ami.jobs.schemas import ids_only_param, incomplete_only_param +from ami.jobs.serializers import ( + MLJobResultsRequestSerializer, + MLJobResultsResponseSerializer, + MLJobTasksRequestSerializer, + MLJobTasksResponseSerializer, +) from ami.jobs.tasks import process_nats_pipeline_result from ami.main.api.schemas import project_id_doc_param from ami.main.api.views import DefaultViewSet -from ami.ml.schemas import PipelineTaskResult from ami.utils.fields import url_boolean_param from .models import Job, JobDispatchMode, JobState @@ -238,24 +243,25 @@ def list(self, request, *args, **kwargs): return super().list(request, *args, **kwargs) @extend_schema( - parameters=[batch_param], - responses={200: dict}, + request=MLJobTasksRequestSerializer, + responses={200: MLJobTasksResponseSerializer}, + parameters=[project_id_doc_param], ) - @action(detail=True, methods=["get"], name="tasks") + @action(detail=True, methods=["post"], name="tasks") def tasks(self, request, pk=None): """ - Get tasks from the job queue. + Fetch tasks from the job queue (POST). Returns task data with reply_subject for acknowledgment. External workers should: - 1. Call this endpoint to get tasks + 1. POST to this endpoint with {"batch_size": N} 2. Process the tasks - 3. POST to /jobs/{id}/result/ with the reply_subject to acknowledge + 3. 
POST to /jobs/{id}/result/ with the results """ + serializer = MLJobTasksRequestSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + batch_size = serializer.validated_data["batch_size"] + job: Job = self.get_object() - try: - batch = IntegerField(required=True, min_value=1).clean(request.query_params.get("batch")) - except Exception as e: - raise ValidationError({"batch": str(e)}) from e # Only async_api jobs have tasks fetchable from NATS if job.dispatch_mode != JobDispatchMode.ASYNC_API: @@ -277,7 +283,7 @@ def tasks(self, request, pk=None): async def get_tasks(): async with TaskQueueManager() as manager: - return [task.dict() for task in await manager.reserve_tasks(job.pk, count=batch, timeout=0.5)] + return [task.dict() for task in await manager.reserve_tasks(job.pk, count=batch_size, timeout=0.5)] try: tasks = async_to_sync(get_tasks)() @@ -287,14 +293,19 @@ async def get_tasks(): return Response({"tasks": tasks}) + @extend_schema( + request=MLJobResultsRequestSerializer, + responses={200: MLJobResultsResponseSerializer}, + parameters=[project_id_doc_param], + ) @action(detail=True, methods=["post"], name="result") def result(self, request, pk=None): """ - The request body should be a list of results: list[PipelineTaskResult] + Submit pipeline results. - This endpoint accepts a list of pipeline results and queues them for - background processing. Each result will be validated, saved to the database, - and acknowledged via NATS in a Celery task. + Accepts: {"results": [PipelineTaskResult, ...]} + + Results are validated then queued for background processing via Celery. """ job = self.get_object() @@ -302,20 +313,11 @@ def result(self, request, pk=None): # Record heartbeat for async processing services on this pipeline _mark_pipeline_pull_services_seen(job) - # Validate request data is a list - if isinstance(request.data, list): - results = request.data - else: - results = [request.data] + serializer = MLJobResultsRequestSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + validated_results = serializer.validated_data["results"] try: - # Pre-validate all results before enqueuing any tasks - # This prevents partial queueing and duplicate task processing - validated_results = [] - for item in results: - task_result = PipelineTaskResult(**item) - validated_results.append(task_result) - # All validation passed, now queue all tasks queued_tasks = [] for task_result in validated_results: @@ -337,27 +339,28 @@ def result(self, request, pk=None): ) logger.info( - f"Queued pipeline result processing for job {job.pk}, " - f"task_id: {task.id}, reply_subject: {reply_subject}" + "Queued pipeline result for job %s, task_id: %s, reply_subject: %s", + job.pk, + task.id, + reply_subject, ) return Response( { "status": "accepted", "job_id": job.pk, - "results_queued": len([t for t in queued_tasks if t["status"] == "queued"]), + "results_queued": len(queued_tasks), "tasks": queued_tasks, } ) - except pydantic.ValidationError as e: - raise ValidationError(f"Invalid result data: {e}") from e - except Exception as e: - logger.error(f"Failed to queue pipeline results for job {job.pk}: {e}") + except (OSError, kombu.exceptions.KombuError) as e: + logger.error("Failed to queue pipeline results for job %s: %s", job.pk, e) return Response( { "status": "error", "job_id": job.pk, + "detail": "Task queue temporarily unavailable", }, - status=500, + status=503, ) diff --git a/ami/ml/schemas.py b/ami/ml/schemas.py index 7449c59e6..9322e4116 100644 --- a/ami/ml/schemas.py 
+++ b/ami/ml/schemas.py @@ -262,6 +262,24 @@ class PipelineProcessingTask(pydantic.BaseModel): # config: PipelineRequestConfigParameters | dict | None = None +class ProcessingServiceClientInfo(pydantic.BaseModel): + """Identity metadata sent by a processing service worker. + + A single ProcessingService record in the database may have multiple + physical workers, pods, or machines running simultaneously. This model + lets the server distinguish between them for logging, debugging, and + eventually for per-worker health tracking. + + Fields are intentionally left open for now. Processing services can + send any key-value pairs they find useful (e.g. hostname, pod_name, + software version). The schema will be tightened once real-world usage + patterns emerge. + """ + + class Config: + extra = "allow" + + class PipelineTaskResult(pydantic.BaseModel): """ The result from processing a single PipelineProcessingTask.
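---

For reference, here is an illustrative worker-side sketch of the new pull protocol, as exercised by the tests above. It is not part of the patch: ANTENNA_URL, the Token auth header, JOB_ID, and process_image() are placeholders, and the error-result field names follow the error payload used in test_tasks.py. The request and response shapes come from MLJobTasksRequestSerializer, MLJobResultsRequestSerializer, and MLJobResultsResponseSerializer.

import socket
import time

import requests

ANTENNA_URL = "https://antenna.example.org/api/v2"  # placeholder base URL
HEADERS = {"Authorization": "Token <API_TOKEN>"}  # illustrative auth scheme
JOB_ID = 123  # placeholder job id
CLIENT_INFO = {"hostname": socket.gethostname()}  # free-form: extra = "allow"


def process_image(task: dict) -> dict:
    """Placeholder: run the pipeline on the task's image and return a
    PipelineResultsResponse-shaped dict, or raise on failure."""
    raise NotImplementedError


def poll_once() -> int:
    # 1. POST {"batch_size": N} to /jobs/{id}/tasks/ to reserve work.
    resp = requests.post(
        f"{ANTENNA_URL}/jobs/{JOB_ID}/tasks/",
        json={"batch_size": 4, "client_info": CLIENT_INFO},
        headers=HEADERS,
        timeout=30,
    )
    resp.raise_for_status()
    tasks = resp.json()["tasks"]  # empty list when no work or job is inactive
    if not tasks:
        return 0

    # 2. Process each task, keeping its reply_subject for result correlation.
    results = []
    for task in tasks:
        try:
            result = process_image(task)
        except Exception as exc:  # report failures as error results
            result = {"error": str(exc), "image_id": task.get("image_id")}
        results.append({"reply_subject": task["reply_subject"], "result": result})

    # 3. POST the wrapped payload back; a bare list is now rejected with 400.
    for attempt in range(3):
        ack = requests.post(
            f"{ANTENNA_URL}/jobs/{JOB_ID}/result/",
            json={"results": results, "client_info": CLIENT_INFO},
            headers=HEADERS,
            timeout=30,
        )
        if ack.status_code != 503:  # 503 means the Celery broker is down; retry
            break
        time.sleep(2**attempt)
    ack.raise_for_status()
    return ack.json()["results_queued"]

The acknowledgment body matches MLJobResultsResponseSerializer: results_queued now equals the number of queued Celery tasks, and each entry in "tasks" is a QueuedTaskAcknowledgment carrying the reply_subject, a status, and the Celery task_id for traceability.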