diff --git a/components/spider-core/src/types/io.rs b/components/spider-core/src/types/io.rs index 0186fb46..00c06057 100644 --- a/components/spider-core/src/types/io.rs +++ b/components/spider-core/src/types/io.rs @@ -6,7 +6,7 @@ use crate::{ }; /// Represents an input of a task. -#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] pub enum TaskInput { ValuePayload(Vec), } diff --git a/components/spider-storage/src/cache.rs b/components/spider-storage/src/cache.rs index 4ace32b7..d520f519 100644 --- a/components/spider-storage/src/cache.rs +++ b/components/spider-storage/src/cache.rs @@ -3,6 +3,7 @@ use spider_core::task::TaskIndex; pub mod error; pub mod io; pub mod job; +pub mod job_submission; mod sync; pub mod task; diff --git a/components/spider-storage/src/cache/error.rs b/components/spider-storage/src/cache/error.rs index aa1a24bd..29def958 100644 --- a/components/spider-storage/src/cache/error.rs +++ b/components/spider-storage/src/cache/error.rs @@ -41,8 +41,11 @@ pub enum InternalError { #[error("task graph corrupted: {0}")] TaskGraphCorrupted(String), - #[error("task graph input size mismatch: expected {0}, got {1}")] - TaskGraphInputsSizeMismatch(usize, usize), + #[error("task graph must contain at least one task")] + TaskGraphEmpty, + + #[error("task graph input size mismatch: expected {expected}, got {actual}")] + TaskGraphInputSizeMismatch { expected: usize, actual: usize }, #[error("job not started")] JobNotStarted, diff --git a/components/spider-storage/src/cache/job.rs b/components/spider-storage/src/cache/job.rs index d2b8e20c..5c575e8e 100644 --- a/components/spider-storage/src/cache/job.rs +++ b/components/spider-storage/src/cache/job.rs @@ -8,10 +8,10 @@ use std::{ use spider_core::{ job::JobState, - task::{TaskGraph as SubmittedTaskGraph, TaskIndex, TaskState}, + task::{TaskIndex, TaskState}, types::{ id::{ExecutionManagerId, JobId, ResourceGroupId, TaskInstanceId}, - io::{ExecutionContext, TaskInput, TaskOutput}, + io::{ExecutionContext, TaskOutput}, }, }; use tokio::sync::{RwLockReadGuard, RwLockWriteGuard}; @@ -20,6 +20,7 @@ use crate::{ cache::{ TaskId, error::{CacheError, InternalError, InternalError::UnexpectedJobState, StaleStateError}, + job_submission::ValidatedJobSubmission, task::TaskGraph, }, db::InternalJobOrchestration, @@ -63,27 +64,17 @@ impl< /// /// Returns an error if: /// - /// * [`InternalError::TaskGraphCorrupted`] if the given task graph doesn't contain any tasks. - /// The current version of JCB requires the job contains at least one task. /// * Forwards [`TaskGraph::create`]'s return values on failure. pub async fn create( id: JobId, owner_id: ResourceGroupId, - submitted_task_graph: &SubmittedTaskGraph, - inputs: Vec, + job_submission: ValidatedJobSubmission, ready_queue_sender: ReadyQueueSenderType, db_connector: DbConnectorType, task_instance_pool_connector: TaskInstancePoolConnectorType, ) -> Result { - let num_tasks = submitted_task_graph.get_num_tasks(); - if 0 == num_tasks { - return Err(InternalError::TaskGraphCorrupted( - "task graph with no task is unsupported".to_owned(), - ) - .into()); - } - - let task_graph = TaskGraph::create(submitted_task_graph, inputs).await?; + let num_tasks = job_submission.task_graph().get_num_tasks(); + let task_graph = TaskGraph::create(job_submission).await?; let job_execution_state = JobExecutionState { state: JobState::Ready, task_graph, diff --git a/components/spider-storage/src/cache/job_submission.rs b/components/spider-storage/src/cache/job_submission.rs new file mode 100644 index 00000000..ca444b7e --- /dev/null +++ b/components/spider-storage/src/cache/job_submission.rs @@ -0,0 +1,159 @@ +use spider_core::{task::TaskGraph, types::io::TaskInput}; + +use super::error::InternalError; + +/// A validated wrapper around a task graph and its corresponding job inputs. +/// +/// This type guarantees at construction time that: +/// +/// * The task graph contains at least one task. +/// * The number of job inputs matches the number of graph inputs expected by the task graph. +/// +/// By passing this type through the call chain, downstream consumers can trust the consistency +/// invariant without re-validating. +#[derive(Debug)] +pub struct ValidatedJobSubmission { + task_graph: TaskGraph, + inputs: Vec, +} + +impl ValidatedJobSubmission { + /// Creates a new validated job submission. + /// + /// # Returns + /// + /// The validated job submission on success. + /// + /// # Errors + /// + /// Returns an error if: + /// + /// * [`InternalError::TaskGraphEmpty`] if the task graph contains no tasks. + /// * [`InternalError::TaskGraphInputSizeMismatch`] if the number of inputs does not match the + /// number of graph inputs. + pub fn create(task_graph: TaskGraph, inputs: Vec) -> Result { + let num_tasks = task_graph.get_num_tasks(); + if num_tasks == 0 { + return Err(InternalError::TaskGraphEmpty); + } + let expected_num_inputs = task_graph.get_task_graph_input_indices().len(); + let actual_num_inputs = inputs.len(); + if expected_num_inputs != actual_num_inputs { + return Err(InternalError::TaskGraphInputSizeMismatch { + expected: expected_num_inputs, + actual: actual_num_inputs, + }); + } + Ok(Self { task_graph, inputs }) + } + + /// # Returns + /// + /// A reference to the validated task graph. + #[must_use] + pub const fn task_graph(&self) -> &TaskGraph { + &self.task_graph + } + + /// # Returns + /// + /// A reference to the validated job inputs. + #[must_use] + pub fn inputs(&self) -> &[TaskInput] { + &self.inputs + } + + /// Consumes the wrapper and returns the owned task graph and job inputs. + /// + /// # Returns + /// + /// A tuple of `(task_graph, inputs)`. + #[must_use] + pub fn into_parts(self) -> (TaskGraph, Vec) { + (self.task_graph, self.inputs) + } +} + +#[cfg(test)] +mod tests { + use spider_core::{ + task::{ + DataTypeDescriptor, + ExecutionPolicy, + TaskDescriptor, + TaskGraph as SubmittedTaskGraph, + TdlContext, + ValueTypeDescriptor, + }, + types::io::TaskInput, + }; + + use super::{super::error::InternalError, *}; + + fn create_single_input_task_graph() -> SubmittedTaskGraph { + let bytes_type = DataTypeDescriptor::Value(ValueTypeDescriptor::bytes()); + let mut graph = + SubmittedTaskGraph::new(None, None).expect("task graph creation should succeed"); + graph + .insert_task(TaskDescriptor { + tdl_context: TdlContext { + package: "test_pkg".to_owned(), + task_func: "test_fn".to_owned(), + }, + execution_policy: Some(ExecutionPolicy::default()), + inputs: vec![bytes_type], + outputs: vec![], + input_sources: None, + }) + .expect("task insertion should succeed"); + graph + } + + #[test] + fn valid_job_submission_succeeds() { + let graph = create_single_input_task_graph(); + let inputs = vec![TaskInput::ValuePayload(vec![1u8; 4])]; + let result = ValidatedJobSubmission::create(graph, inputs); + assert!(result.is_ok(), "valid submission should succeed"); + } + + #[test] + fn empty_task_graph_fails() { + let graph = + SubmittedTaskGraph::new(None, None).expect("task graph creation should succeed"); + let inputs = vec![]; + let result = ValidatedJobSubmission::create(graph, inputs); + assert!( + matches!(result, Err(InternalError::TaskGraphEmpty)), + "empty task graph should return EmptyTaskGraph" + ); + } + + #[test] + fn mismatched_input_count_fails() { + let graph = create_single_input_task_graph(); + let inputs = vec![]; + let result = ValidatedJobSubmission::create(graph, inputs); + assert!( + matches!( + result, + Err(InternalError::TaskGraphInputSizeMismatch { + expected: 1, + actual: 0 + }) + ), + "mismatched input count should return TaskGraphInputSizeMismatch" + ); + } + + #[test] + fn into_parts_returns_owned_components() { + let graph = create_single_input_task_graph(); + let inputs = vec![TaskInput::ValuePayload(vec![1u8; 4])]; + let submission = + ValidatedJobSubmission::create(graph, inputs).expect("submission should be valid"); + let (graph, inputs) = submission.into_parts(); + assert_eq!(graph.get_num_tasks(), 1, "task graph should have 1 task"); + assert_eq!(inputs.len(), 1, "should have 1 input"); + } +} diff --git a/components/spider-storage/src/cache/task.rs b/components/spider-storage/src/cache/task.rs index 63ae1a9b..416a3264 100644 --- a/components/spider-storage/src/cache/task.rs +++ b/components/spider-storage/src/cache/task.rs @@ -5,15 +5,7 @@ use std::{ }; use spider_core::{ - task::{ - Task, - TaskGraph as SubmittedTaskGraph, - TaskIndex, - TaskState, - TdlContext, - TerminationTaskDescriptor, - TimeoutPolicy, - }, + task::{Task, TaskIndex, TaskState, TdlContext, TerminationTaskDescriptor, TimeoutPolicy}, types::{ id::TaskInstanceId, io::{ExecutionContext, TaskInput, TaskOutput}, @@ -24,6 +16,7 @@ use tokio::sync::RwLock; use crate::cache::{ error::{CacheError, InternalError, StaleStateError}, io::{InputReader, OutputReader, OutputWriter, ValuePayload}, + job_submission::ValidatedJobSubmission, sync::{Reader, SharedRw, Writer}, }; @@ -38,7 +31,7 @@ pub struct TaskGraph { impl TaskGraph { /// Factory function. /// - /// Creates a new task graph from a submitted task graph and the input task inputs. + /// Creates a new task graph from a validated job submission. /// /// # Returns /// @@ -51,28 +44,18 @@ impl TaskGraph { /// * [`InternalError::TaskGraphCorrupted`] if: /// * Any dataflow deps' index is out-of-range. /// * Any task index is out-of-range. - /// * [`InternalError::TaskGraphInputsSizeMismatch`] if the number of provided inputs does not - /// match the task graph’s expected number of inputs. /// * Forwards [`SharedTaskControlBlock::create`]'s return values on failure. /// /// # Panics /// /// Panics if the internal TCB buffer is corrupted. - pub async fn create( - submitted_task_graph: &SubmittedTaskGraph, - inputs: Vec, - ) -> Result { + pub async fn create(job_submission: ValidatedJobSubmission) -> Result { + let (submitted_task_graph, inputs) = job_submission.into_parts(); let dataflow_dep_buffer: Vec> = (0..submitted_task_graph .get_num_dataflow_deps()) .map(|_| SharedRw::new(RwLock::new(ValuePayload::default()))) .collect(); let task_graph_input_indices = submitted_task_graph.get_task_graph_input_indices(); - if inputs.len() != task_graph_input_indices.len() { - return Err(InternalError::TaskGraphInputsSizeMismatch( - task_graph_input_indices.len(), - inputs.len(), - )); - } for (deps_index, input) in task_graph_input_indices.into_iter().zip(inputs) { let dataflow_dep = dataflow_dep_buffer.get(deps_index).ok_or_else(|| { InternalError::TaskGraphCorrupted( @@ -949,6 +932,7 @@ mod tests { }; use super::*; + use crate::cache::job_submission::ValidatedJobSubmission; /// # Returns /// @@ -1092,7 +1076,9 @@ mod tests { let inputs: Vec = (0..num_inputs) .map(|_| TaskInput::ValuePayload(vec![0u8; 4])) .collect(); - TaskGraph::create(&submitted, inputs) + let job_submission = ValidatedJobSubmission::create(submitted, inputs) + .expect("job submission should be valid"); + TaskGraph::create(job_submission) .await .expect("cache task graph creation should succeed") } @@ -1107,7 +1093,7 @@ mod tests { max_num_instances: u32, max_num_retry: u32, ) -> SharedTerminationTaskControlBlock { - let submitted = SubmittedTaskGraph::new( + let mut submitted = SubmittedTaskGraph::new( Some(TerminationTaskDescriptor { tdl_context: TdlContext { package: "test_pkg".to_owned(), @@ -1122,7 +1108,21 @@ mod tests { None, ) .expect("task graph with commit task should be created"); - let task_graph = TaskGraph::create(&submitted, vec![]) + submitted + .insert_task(TaskDescriptor { + tdl_context: TdlContext { + package: "test_pkg".to_owned(), + task_func: "dummy_fn".to_owned(), + }, + execution_policy: Some(ExecutionPolicy::default()), + inputs: vec![], + outputs: vec![], + input_sources: None, + }) + .expect("task insertion should succeed"); + let job_submission = ValidatedJobSubmission::create(submitted, vec![]) + .expect("job submission should be valid"); + let task_graph = TaskGraph::create(job_submission) .await .expect("cache task graph creation should succeed"); task_graph @@ -1237,7 +1237,9 @@ mod tests { TaskInput::ValuePayload(input_a), TaskInput::ValuePayload(input_b), ]; - TaskGraph::create(&submitted, inputs) + let job_submission = ValidatedJobSubmission::create(submitted, inputs) + .expect("job submission should be valid"); + TaskGraph::create(job_submission) .await .expect("cache task graph creation should succeed") } diff --git a/components/spider-storage/src/db/mariadb.rs b/components/spider-storage/src/db/mariadb.rs index 1005e9c1..faeda2a6 100644 --- a/components/spider-storage/src/db/mariadb.rs +++ b/components/spider-storage/src/db/mariadb.rs @@ -5,16 +5,16 @@ use const_format::formatcp; use secrecy::ExposeSecret; use spider_core::{ job::JobState, - task::TaskGraph, types::{ id::{ExecutionManagerId, JobId, ResourceGroupId, SessionId}, - io::{TaskInput, TaskOutput}, + io::TaskOutput, }, }; use spider_derive::MySqlEnum; use sqlx::{MySqlPool, mysql::MySqlDatabaseError}; use crate::{ + cache::job_submission::ValidatedJobSubmission, config::DatabaseConfig, db::{ DbError, @@ -98,8 +98,7 @@ impl ExternalJobOrchestration for MariaDbStorageConnector { async fn register( &self, resource_group_id: ResourceGroupId, - task_graph: &TaskGraph, - job_inputs: &[TaskInput], + job_submission: &ValidatedJobSubmission, ) -> Result { const INSERT_QUERY: &str = formatcp!( "INSERT INTO `{table}` (`resource_group_id`, `serialized_task_graph`, \ @@ -107,6 +106,8 @@ impl ExternalJobOrchestration for MariaDbStorageConnector { table = JOBS_TABLE_NAME, ); + let task_graph = job_submission.task_graph(); + let job_inputs = job_submission.inputs(); let serialized_task_graph = task_graph .to_json() .map_err(|e| DbError::TaskGraphSerializationFailure(Box::new(e)))?; diff --git a/components/spider-storage/src/db/protocol.rs b/components/spider-storage/src/db/protocol.rs index 89ba3d9c..0b9e297f 100644 --- a/components/spider-storage/src/db/protocol.rs +++ b/components/spider-storage/src/db/protocol.rs @@ -3,14 +3,13 @@ use std::net::IpAddr; use async_trait::async_trait; use spider_core::{ job::JobState, - task::TaskGraph, types::{ id::{ExecutionManagerId, JobId, ResourceGroupId, SessionId}, - io::{TaskInput, TaskOutput}, + io::TaskOutput, }, }; -use crate::db::error::DbError; +use crate::{cache::job_submission::ValidatedJobSubmission, db::error::DbError}; /// The database storage interface. A database storage must implement the following traits: /// @@ -36,8 +35,7 @@ pub trait ExternalJobOrchestration { /// # Parameters /// /// * `resource_group_id` - The owner of the created job. - /// * `task_graph` - The task graph representing the job's tasks and their dependencies. - /// * `job_inputs` - A slice of job inputs required for the job. + /// * `job_submission` - The validated job submission containing the task graph and job inputs. /// /// # Returns /// @@ -48,20 +46,13 @@ pub trait ExternalJobOrchestration { /// Returns an error if: /// /// * [`DbError::ResourceGroupNotFound`] if the `resource_group_id` does not exist. - /// * [`DbError::TaskGraphSerializationFailure`] if the `task_graph` serialization fails. - /// * [`DbError::ValueSerializationFailure`] if the `job_inputs` serialization fails. + /// * [`DbError::TaskGraphSerializationFailure`] if the task graph serialization fails. + /// * [`DbError::ValueSerializationFailure`] if the job inputs serialization fails. /// * Forwards [`sqlx::error::Error`] on DB operation failure. - /// - /// # Note - /// - /// This function assumes that the `task_graph` and `job_inputs` are consistent. - /// - /// TODO: Fix this when #284 is addressed. async fn register( &self, resource_group_id: ResourceGroupId, - task_graph: &TaskGraph, - job_inputs: &[TaskInput], + job_submission: &ValidatedJobSubmission, ) -> Result; /// Gets the state of a job. diff --git a/components/spider-storage/src/state/job_cache.rs b/components/spider-storage/src/state/job_cache.rs index 40bb64c8..6ad3c7ce 100644 --- a/components/spider-storage/src/state/job_cache.rs +++ b/components/spider-storage/src/state/job_cache.rs @@ -151,6 +151,7 @@ mod tests { cache::{ error::InternalError, job::SharedJobControlBlock, + job_submission::ValidatedJobSubmission, task::{SharedTaskControlBlock, SharedTerminationTaskControlBlock}, }, db::DbError, @@ -276,11 +277,13 @@ mod tests { }) .expect("task insertion should succeed"); + let job_submission = + ValidatedJobSubmission::create(submitted, vec![TaskInput::ValuePayload(vec![0u8; 4])]) + .expect("job submission should be valid"); SharedJobControlBlock::create( job_id, spider_core::types::id::ResourceGroupId::new(), - &submitted, - vec![TaskInput::ValuePayload(vec![0u8; 4])], + job_submission, MockReadyQueueSender, MockDbConnector, MockTaskInstancePoolConnector, @@ -454,11 +457,13 @@ mod tests { .expect("task insertion should succeed"); let job_id = JobId::new(); + let job_submission = + ValidatedJobSubmission::create(submitted, vec![TaskInput::ValuePayload(vec![0u8; 4])]) + .expect("job submission should be valid"); let jcb = SharedJobControlBlock::create( job_id, spider_core::types::id::ResourceGroupId::new(), - &submitted, - vec![TaskInput::ValuePayload(vec![0u8; 4])], + job_submission, sender, MockDbConnector, MockTaskInstancePoolConnector, diff --git a/components/spider-storage/src/task_instance_pool.rs b/components/spider-storage/src/task_instance_pool.rs index b5779ff0..ace45ce6 100644 --- a/components/spider-storage/src/task_instance_pool.rs +++ b/components/spider-storage/src/task_instance_pool.rs @@ -540,6 +540,7 @@ mod tests { use tokio::sync::Mutex; use super::*; + use crate::cache::job_submission::ValidatedJobSubmission; const DEFAULT_CHANNEL_SIZE: usize = 128; @@ -661,12 +662,12 @@ mod tests { input_sources: None, }) .expect("task insertion should succeed"); - let task_graph = crate::cache::task::TaskGraph::create( - &submitted, - vec![TaskInput::ValuePayload(vec![0u8; 4])], - ) - .await - .expect("cache task graph creation should succeed"); + let job_submission = + ValidatedJobSubmission::create(submitted, vec![TaskInput::ValuePayload(vec![0u8; 4])]) + .expect("job submission should be valid"); + let task_graph = crate::cache::task::TaskGraph::create(job_submission) + .await + .expect("cache task graph creation should succeed"); task_graph .get_task_control_block(0) .expect("task control block should exist") diff --git a/components/spider-storage/tests/jcb_test.rs b/components/spider-storage/tests/jcb_test.rs index 314f574a..6f444343 100644 --- a/components/spider-storage/tests/jcb_test.rs +++ b/components/spider-storage/tests/jcb_test.rs @@ -1,5 +1,8 @@ use spider_core::job::JobState; -use spider_storage::db::{ExternalJobOrchestration, InternalJobOrchestration}; +use spider_storage::{ + cache::job_submission::ValidatedJobSubmission, + db::{ExternalJobOrchestration, InternalJobOrchestration}, +}; use super::{ scheduling_infra::{ @@ -47,9 +50,10 @@ async fn test_flat_success( ) -> WorkloadResult { let (graph, inputs) = build_flat_task_graph(10_000, 1024, true, true); let num_tasks = graph.get_num_tasks(); + let job_submission = + ValidatedJobSubmission::create(graph, inputs).expect("job submission should be valid"); let result = run_workload( - &graph, - inputs, + job_submission, db_connector_factory, CancelPolicy::Never, default_output_handler(1024), @@ -89,9 +93,10 @@ async fn test_flat_cancel( db_connector_factory: impl DbConnectorFactory, ) -> WorkloadResult { let (graph, inputs) = build_flat_task_graph(10_000, 1024, true, true); + let job_submission = + ValidatedJobSubmission::create(graph, inputs).expect("job submission should be valid"); let result = run_workload( - &graph, - inputs, + job_submission, db_connector_factory, CancelPolicy::Immediate, default_output_handler(1024), @@ -133,9 +138,10 @@ async fn test_neural_net_success WorkloadResult { let (graph, inputs) = build_neural_net_task_graph(); let num_tasks = graph.get_num_tasks(); + let job_submission = + ValidatedJobSubmission::create(graph, inputs).expect("job submission should be valid"); let result = run_workload( - &graph, - inputs, + job_submission, db_connector_factory, CancelPolicy::Never, default_output_handler(128), @@ -178,9 +184,10 @@ async fn test_neural_net_cancel, ) -> WorkloadResult { let (graph, inputs) = build_neural_net_task_graph(); + let job_submission = + ValidatedJobSubmission::create(graph, inputs).expect("job submission should be valid"); let result = run_workload( - &graph, - inputs, + job_submission, db_connector_factory, CancelPolicy::Immediate, default_output_handler(128), @@ -220,9 +227,10 @@ async fn test_always_fail_terminates_job, ) -> WorkloadResult { let (graph, inputs) = build_flat_task_graph(3, 128, false, false); + let job_submission = + ValidatedJobSubmission::create(graph, inputs).expect("job submission should be valid"); let result = run_workload( - &graph, - inputs, + job_submission, db_connector_factory, CancelPolicy::Never, default_output_handler(128), @@ -258,9 +266,10 @@ async fn test_concurrent_success_and_cancel, ) -> WorkloadResult { let (graph, inputs) = build_flat_task_graph(100, 128, true, true); + let job_submission = + ValidatedJobSubmission::create(graph, inputs).expect("job submission should be valid"); let result = run_workload( - &graph, - inputs, + job_submission, db_connector_factory, CancelPolicy::Concurrent, default_output_handler(128), diff --git a/components/spider-storage/tests/mariadb_test.rs b/components/spider-storage/tests/mariadb_test.rs index 05c9742f..3b90ab07 100644 --- a/components/spider-storage/tests/mariadb_test.rs +++ b/components/spider-storage/tests/mariadb_test.rs @@ -10,14 +10,17 @@ use spider_core::{ io::TaskInput, }, }; -use spider_storage::db::{ - DbError, - ExecutionManagerLivenessManagement, - ExternalJobOrchestration, - InternalJobOrchestration, - MariaDbStorageConnector, - ResourceGroupManagement, - SessionManagement, +use spider_storage::{ + cache::job_submission::ValidatedJobSubmission, + db::{ + DbError, + ExecutionManagerLivenessManagement, + ExternalJobOrchestration, + InternalJobOrchestration, + MariaDbStorageConnector, + ResourceGroupManagement, + SessionManagement, + }, }; use super::{ @@ -58,9 +61,11 @@ async fn test_register_job() { let storage = create_mariadb_connector().await; let rg_id = create_test_resource_group(&storage).await; let (graph, inputs) = single_task_graph(); + let job_submission = + ValidatedJobSubmission::create(graph, inputs).expect("job submission should be valid"); let job_id = storage - .register(rg_id, &graph, inputs.as_slice()) + .register(rg_id, &job_submission) .await .expect("register should succeed"); @@ -77,10 +82,10 @@ async fn test_register_job_invalid_resource_group() { let storage = create_mariadb_connector().await; let fake_rg_id = ResourceGroupId::new(); let (graph, inputs) = single_task_graph(); + let job_submission = + ValidatedJobSubmission::create(graph, inputs).expect("job submission should be valid"); - let result = storage - .register(fake_rg_id, &graph, inputs.as_slice()) - .await; + let result = storage.register(fake_rg_id, &job_submission).await; assert!( matches!(result, Err(DbError::ResourceGroupNotFound(_))), @@ -94,9 +99,11 @@ async fn test_start_job() { let storage = create_mariadb_connector().await; let rg_id = create_test_resource_group(&storage).await; let (graph, inputs) = single_task_graph(); + let job_submission = + ValidatedJobSubmission::create(graph, inputs).expect("job submission should be valid"); let job_id = storage - .register(rg_id, &graph, inputs.as_slice()) + .register(rg_id, &job_submission) .await .expect("register should succeed"); @@ -115,9 +122,11 @@ async fn test_start_job_wrong_state() { let storage = create_mariadb_connector().await; let rg_id = create_test_resource_group(&storage).await; let (graph, inputs) = single_task_graph(); + let job_submission = + ValidatedJobSubmission::create(graph, inputs).expect("job submission should be valid"); let job_id = storage - .register(rg_id, &graph, inputs.as_slice()) + .register(rg_id, &job_submission) .await .expect("register should succeed"); @@ -136,9 +145,11 @@ async fn test_cancel_job_without_cleanup_transitions_to_cancelled() { let storage = create_mariadb_connector().await; let rg_id = create_test_resource_group(&storage).await; let (graph, inputs) = single_task_graph(); + let job_submission = + ValidatedJobSubmission::create(graph, inputs).expect("job submission should be valid"); let job_id = storage - .register(rg_id, &graph, inputs.as_slice()) + .register(rg_id, &job_submission) .await .expect("register should succeed"); @@ -162,9 +173,11 @@ async fn test_get_outputs_succeeded_job() { let storage = create_mariadb_connector().await; let rg_id = create_test_resource_group(&storage).await; let (graph, inputs) = single_task_graph(); + let job_submission = + ValidatedJobSubmission::create(graph, inputs).expect("job submission should be valid"); let job_id = storage - .register(rg_id, &graph, inputs.as_slice()) + .register(rg_id, &job_submission) .await .expect("register should succeed"); @@ -188,9 +201,11 @@ async fn test_get_outputs_wrong_state() { let storage = create_mariadb_connector().await; let rg_id = create_test_resource_group(&storage).await; let (graph, inputs) = single_task_graph(); + let job_submission = + ValidatedJobSubmission::create(graph, inputs).expect("job submission should be valid"); let job_id = storage - .register(rg_id, &graph, inputs.as_slice()) + .register(rg_id, &job_submission) .await .expect("register should succeed"); @@ -207,9 +222,11 @@ async fn test_get_error_failed_job() { let storage = create_mariadb_connector().await; let rg_id = create_test_resource_group(&storage).await; let (graph, inputs) = single_task_graph(); + let job_submission = + ValidatedJobSubmission::create(graph, inputs).expect("job submission should be valid"); let job_id = storage - .register(rg_id, &graph, inputs.as_slice()) + .register(rg_id, &job_submission) .await .expect("register should succeed"); @@ -232,9 +249,11 @@ async fn test_get_error_wrong_state() { let storage = create_mariadb_connector().await; let rg_id = create_test_resource_group(&storage).await; let (graph, inputs) = single_task_graph(); + let job_submission = + ValidatedJobSubmission::create(graph, inputs).expect("job submission should be valid"); let job_id = storage - .register(rg_id, &graph, inputs.as_slice()) + .register(rg_id, &job_submission) .await .expect("register should succeed"); @@ -251,9 +270,11 @@ async fn test_cancel_job_with_cleanup_transitions_to_cleanup_ready() { let storage = create_mariadb_connector().await; let rg_id = create_test_resource_group(&storage).await; let (graph, inputs) = single_task_graph(); + let job_submission = + ValidatedJobSubmission::create(graph, inputs).expect("job submission should be valid"); let job_id = storage - .register(rg_id, &graph, inputs.as_slice()) + .register(rg_id, &job_submission) .await .expect("register should succeed"); @@ -276,9 +297,11 @@ async fn test_cancel_already_terminal() { let storage = create_mariadb_connector().await; let rg_id = create_test_resource_group(&storage).await; let (graph, inputs) = single_task_graph(); + let job_submission = + ValidatedJobSubmission::create(graph, inputs).expect("job submission should be valid"); let job_id = storage - .register(rg_id, &graph, inputs.as_slice()) + .register(rg_id, &job_submission) .await .expect("register should succeed"); @@ -304,9 +327,11 @@ async fn test_set_state_valid_transition() { let storage = create_mariadb_connector().await; let rg_id = create_test_resource_group(&storage).await; let (graph, inputs) = single_task_graph(); + let job_submission = + ValidatedJobSubmission::create(graph, inputs).expect("job submission should be valid"); let job_id = storage - .register(rg_id, &graph, inputs.as_slice()) + .register(rg_id, &job_submission) .await .expect("register should succeed"); @@ -327,9 +352,11 @@ async fn test_set_state_invalid_transition() { let storage = create_mariadb_connector().await; let rg_id = create_test_resource_group(&storage).await; let (graph, inputs) = single_task_graph(); + let job_submission = + ValidatedJobSubmission::create(graph, inputs).expect("job submission should be valid"); let job_id = storage - .register(rg_id, &graph, inputs.as_slice()) + .register(rg_id, &job_submission) .await .expect("register should succeed"); @@ -350,9 +377,11 @@ async fn test_commit_outputs_without_commit_task() { let storage = create_mariadb_connector().await; let rg_id = create_test_resource_group(&storage).await; let (graph, inputs) = single_task_graph(); + let job_submission = + ValidatedJobSubmission::create(graph, inputs).expect("job submission should be valid"); let job_id = storage - .register(rg_id, &graph, inputs.as_slice()) + .register(rg_id, &job_submission) .await .expect("register should succeed"); @@ -375,9 +404,11 @@ async fn test_commit_outputs_with_commit_task() { let storage = create_mariadb_connector().await; let rg_id = create_test_resource_group(&storage).await; let (graph, inputs) = single_task_graph(); + let job_submission = + ValidatedJobSubmission::create(graph, inputs).expect("job submission should be valid"); let job_id = storage - .register(rg_id, &graph, inputs.as_slice()) + .register(rg_id, &job_submission) .await .expect("register should succeed"); @@ -403,9 +434,11 @@ async fn test_commit_outputs_wrong_state() { let storage = create_mariadb_connector().await; let rg_id = create_test_resource_group(&storage).await; let (graph, inputs) = single_task_graph(); + let job_submission = + ValidatedJobSubmission::create(graph, inputs).expect("job submission should be valid"); let job_id = storage - .register(rg_id, &graph, inputs.as_slice()) + .register(rg_id, &job_submission) .await .expect("register should succeed"); @@ -427,9 +460,11 @@ async fn test_fail_job() { let storage = create_mariadb_connector().await; let rg_id = create_test_resource_group(&storage).await; let (graph, inputs) = single_task_graph(); + let job_submission = + ValidatedJobSubmission::create(graph, inputs).expect("job submission should be valid"); let job_id = storage - .register(rg_id, &graph, inputs.as_slice()) + .register(rg_id, &job_submission) .await .expect("register should succeed"); @@ -452,9 +487,11 @@ async fn test_fail_terminal_state() { let storage = create_mariadb_connector().await; let rg_id = create_test_resource_group(&storage).await; let (graph, inputs) = single_task_graph(); + let job_submission = + ValidatedJobSubmission::create(graph, inputs).expect("job submission should be valid"); let job_id = storage - .register(rg_id, &graph, inputs.as_slice()) + .register(rg_id, &job_submission) .await .expect("register should succeed"); @@ -481,9 +518,11 @@ async fn test_delete_expired_terminated_jobs() { let storage = create_mariadb_connector().await; let rg_id = create_test_resource_group(&storage).await; let (graph, inputs) = single_task_graph(); + let job_submission = + ValidatedJobSubmission::create(graph, inputs).expect("job submission should be valid"); let job_id = storage - .register(rg_id, &graph, inputs.as_slice()) + .register(rg_id, &job_submission) .await .expect("register should succeed"); @@ -694,9 +733,11 @@ async fn test_cancel_from_ready_state() { let storage = create_mariadb_connector().await; let rg_id = create_test_resource_group(&storage).await; let (graph, inputs) = single_task_graph(); + let job_submission = + ValidatedJobSubmission::create(graph, inputs).expect("job submission should be valid"); let job_id = storage - .register(rg_id, &graph, inputs.as_slice()) + .register(rg_id, &job_submission) .await .expect("register should succeed"); @@ -718,9 +759,11 @@ async fn test_delete_expired_terminated_jobs_no_match() { let storage = create_mariadb_connector().await; let rg_id = create_test_resource_group(&storage).await; let (graph, inputs) = single_task_graph(); + let job_submission = + ValidatedJobSubmission::create(graph, inputs).expect("job submission should be valid"); let job_id = storage - .register(rg_id, &graph, inputs.as_slice()) + .register(rg_id, &job_submission) .await .expect("register should succeed"); diff --git a/components/spider-storage/tests/scheduling_infra.rs b/components/spider-storage/tests/scheduling_infra.rs index 78fbd7d0..d3e5eb98 100644 --- a/components/spider-storage/tests/scheduling_infra.rs +++ b/components/spider-storage/tests/scheduling_infra.rs @@ -85,10 +85,10 @@ use dashmap::DashMap; use rand::{Rng, SeedableRng}; use spider_core::{ job::JobState, - task::{TaskGraph as SubmittedTaskGraph, TaskIndex}, + task::TaskIndex, types::{ id::{ExecutionManagerId, JobId, ResourceGroupId, TaskInstanceId}, - io::{ExecutionContext, TaskInput, TaskOutput}, + io::{ExecutionContext, TaskOutput}, }, }; use spider_storage::{ @@ -96,6 +96,7 @@ use spider_storage::{ TaskId, error::{CacheError, InternalError}, job::SharedJobControlBlock, + job_submission::ValidatedJobSubmission, task::{SharedTaskControlBlock, SharedTerminationTaskControlBlock}, }, db::{DbError, ExternalJobOrchestration, InternalJobOrchestration, MariaDbStorageConnector}, @@ -205,25 +206,23 @@ pub type FactoryReturn = (DbConnectorType, JobId, ResourceGroup /// /// * `DbConnectorType` - The DB-layer connector implementation. /// -/// Receives the submitted task graph and job inputs, performs any required DB setup (e.g. job -/// registration), and returns the connector along with the [`JobId`] and [`ResourceGroupId`] to -/// use for the JCB. +/// Receives the validated job submission, performs any required DB setup (e.g. job registration), +/// and returns the connector along with the [`JobId`] and [`ResourceGroupId`] to use for the JCB. pub trait DbConnectorFactory: - AsyncFnOnce(&SubmittedTaskGraph, &[TaskInput]) -> FactoryReturn + Send { + AsyncFnOnce(&ValidatedJobSubmission) -> FactoryReturn + Send { } impl DbConnectorFactory for AsyncFunc where - AsyncFunc: - AsyncFnOnce(&SubmittedTaskGraph, &[TaskInput]) -> FactoryReturn + Send, + AsyncFunc: AsyncFnOnce(&ValidatedJobSubmission) -> FactoryReturn + Send, { } /// Creates a [`NoopDbConnector`] with default [`JobId`] and [`ResourceGroupId`]. #[must_use] pub fn noop_db_connector_factory() -> impl DbConnectorFactory { - async |_, _| { + async |_: &ValidatedJobSubmission| { ( NoopDbConnector {}, JobId::default(), @@ -317,8 +316,7 @@ pub fn write_instrument_results( /// A [`WorkloadResult`] containing the terminal state and commit/cleanup execution counts. #[allow(clippy::too_many_lines)] pub async fn run_workload( - submitted_task_graph: &SubmittedTaskGraph, - inputs: Vec, + job_submission: ValidatedJobSubmission, db_connector_factory: impl DbConnectorFactory, cancel_policy: CancelPolicy, output_handler: TaskOutputHandler, @@ -330,16 +328,14 @@ pub async fn run_workload( let ready_queue_sender = MockReadyQueueSender { sender: ready_sender, }; - let (db_connector, job_id, resource_group_id) = - db_connector_factory(submitted_task_graph, &inputs).await; + let (db_connector, job_id, resource_group_id) = db_connector_factory(&job_submission).await; let task_instance_pool = MockTaskInstancePool::new(); // Create and start the JCB. let inner_jcb = SharedJobControlBlock::create( job_id, resource_group_id, - submitted_task_graph, - inputs, + job_submission, ready_queue_sender, db_connector, task_instance_pool, @@ -457,8 +453,8 @@ pub fn mariadb_db_connector_factory( storage: MariaDbStorageConnector, rg_id: ResourceGroupId, ) -> impl DbConnectorFactory { - async move |graph, inputs| { - let job_id = ExternalJobOrchestration::register(&storage, rg_id, graph, inputs) + async move |job_submission: &ValidatedJobSubmission| { + let job_id = ExternalJobOrchestration::register(&storage, rg_id, job_submission) .await .expect("register should succeed"); (storage, job_id, rg_id)