diff --git a/src/blueapi/service/interface.py b/src/blueapi/service/interface.py index 335d00477..1c036b859 100644 --- a/src/blueapi/service/interface.py +++ b/src/blueapi/service/interface.py @@ -18,6 +18,7 @@ PlanModel, PythonEnvironmentResponse, SourceInfo, + TaskParamsValidationRequest, TaskRequest, WorkerTask, ) @@ -172,6 +173,22 @@ def submit_task( return worker().submit_task(task) +def validate_task_params( + task_request: TaskParamsValidationRequest, metadata: dict[str, Any] | None = None +) -> bool: + """Validate the params for a task""" + # Can't default arg to mutable data structure: + if metadata is None: + metadata = {} + + task = Task( + name=task_request.name, + params=task_request.params, + metadata=metadata, + ) + return worker().validate_task_params(task) + + def clear_task(task_id: str) -> str: """Remove a task from the worker""" return worker().clear_task(task_id) diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index ed09f9c56..4be5063e8 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -51,9 +51,11 @@ PythonEnvironmentResponse, SourceInfo, StateChangeRequest, + TaskParamsValidationRequest, TaskRequest, TaskResponse, TasksListResponse, + TasksParamValidationResponse, WorkerTask, ) from .runner import WorkerDispatcher @@ -327,6 +329,64 @@ def submit_task( ) from e +example_task_validate_params_request = TaskParamsValidationRequest( + name="count", + params={"detectors": ["x"]}, +) + + +@secure_router_v1.post( + "/validateTaskParams", status_code=status.HTTP_200_OK, tags=[Tag.TASK] +) +@start_as_current_span( + TRACER, + "request", + "task_request.name", + "task_request.params", +) +def validate_task_params( + request: Request, + response: Response, + task_request: Annotated[ + TaskParamsValidationRequest, Body(..., examples=[example_task_request]) + ], + runner: Annotated[WorkerDispatcher, Depends(_runner)], +) -> TasksParamValidationResponse: + """Validate the tasks parameters.""" + try: + # Extract user from jwt if using OIDC (if jwt exists) + access_token: dict[str, Any] | None = getattr( + request.state, "decoded_access_token", None + ) + if access_token: + user: str = access_token.get("fedid", "Unknown") + else: + user = "Unknown" + + validated: bool = runner.run( + interface.validate_task_params, task_request, {"user": user} + ) + return TasksParamValidationResponse(validated=validated) + except ValidationError as e: + # Add body/params context to location and ensure that all required + # fields defined in the generated schema are present + errors = [ + { + "loc": ["body", "params", *err.get("loc", [])], + "msg": err.get("msg", None), + "type": err.get("type", None), + # Input is not listed as required but is useful to have if available + "input": err.get("input", None), + } + for err in e.errors() + ] + + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=errors, + ) from e + + @secure_router_v1.delete( "/tasks/{task_id}", status_code=status.HTTP_200_OK, tags=[Tag.TASK] ) diff --git a/src/blueapi/service/model.py b/src/blueapi/service/model.py index dbaa1d965..0f07c591b 100644 --- a/src/blueapi/service/model.py +++ b/src/blueapi/service/model.py @@ -58,6 +58,16 @@ class TasksListResponse(BlueapiBaseModel): tasks: list[TrackableTask] = Field(description="List of tasks") +class TasksParamValidationResponse(BlueapiBaseModel): + """ + Diagnostic information on the tasks + """ + + validated: bool = Field( + description="Whether the task params were sucessfully validated" + ) + + class TaskRequest(BlueapiBaseModel): """ Request to run a task with related info @@ -72,6 +82,17 @@ class TaskRequest(BlueapiBaseModel): ) +class TaskParamsValidationRequest(BlueapiBaseModel): + """ + Request to validate the parameters of a task + """ + + name: str = Field(description="Name of plan to run") + params: Mapping[str, Any] = Field( + description="Values for parameters to plan, if any", default_factory=dict + ) + + class DeviceRequest(BlueapiBaseModel): """ A query for devices diff --git a/src/blueapi/worker/task_worker.py b/src/blueapi/worker/task_worker.py index caa39fe7a..672646631 100644 --- a/src/blueapi/worker/task_worker.py +++ b/src/blueapi/worker/task_worker.py @@ -292,6 +292,18 @@ def submit_task(self, task: Task) -> str: self._pending_tasks[task_id] = trackable_task return task_id + @start_as_current_span(TRACER, "task.name", "task.params") + def validate_task_params(self, task: Task) -> bool: + """ + Validates the params for a task + Args: + task: A description of the task + Returns: + bool: True of the params are validated + """ + task.prepare_params(self._ctx) # Will raise if parameters are invalid + return True + @start_as_current_span( TRACER, "trackable_task.task_id",