Skip to content
2 changes: 1 addition & 1 deletion components/spider-core/src/types/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::{
};

/// Represents an input of a task.
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub enum TaskInput {
ValuePayload(Vec<u8>),
}
Expand Down
1 change: 1 addition & 0 deletions components/spider-storage/src/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use spider_core::task::TaskIndex;
pub mod error;
pub mod io;
pub mod job;
pub mod job_submission;
mod sync;
pub mod task;

Expand Down
7 changes: 5 additions & 2 deletions components/spider-storage/src/cache/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,11 @@ pub enum InternalError {
#[error("task graph corrupted: {0}")]
TaskGraphCorrupted(String),

#[error("task graph input size mismatch: expected {0}, got {1}")]
TaskGraphInputsSizeMismatch(usize, usize),
#[error("task graph must contain at least one task")]
TaskGraphEmpty,

#[error("task graph input size mismatch: expected {expected}, got {actual}")]
TaskGraphInputSizeMismatch { expected: usize, actual: usize },

#[error("job not started")]
JobNotStarted,
Expand Down
21 changes: 6 additions & 15 deletions components/spider-storage/src/cache/job.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ use std::{

use spider_core::{
job::JobState,
task::{TaskGraph as SubmittedTaskGraph, TaskIndex, TaskState},
task::{TaskIndex, TaskState},
types::{
id::{ExecutionManagerId, JobId, ResourceGroupId, TaskInstanceId},
io::{ExecutionContext, TaskInput, TaskOutput},
io::{ExecutionContext, TaskOutput},
},
};
use tokio::sync::{RwLockReadGuard, RwLockWriteGuard};
Expand All @@ -20,6 +20,7 @@ use crate::{
cache::{
TaskId,
error::{CacheError, InternalError, InternalError::UnexpectedJobState, StaleStateError},
job_submission::ValidatedJobSubmission,
task::TaskGraph,
},
db::InternalJobOrchestration,
Expand Down Expand Up @@ -63,27 +64,17 @@ impl<
///
/// Returns an error if:
///
/// * [`InternalError::TaskGraphCorrupted`] if the given task graph doesn't contain any tasks.
/// The current version of JCB requires the job contains at least one task.
/// * Forwards [`TaskGraph::create`]'s return values on failure.
pub async fn create(
id: JobId,
owner_id: ResourceGroupId,
submitted_task_graph: &SubmittedTaskGraph,
inputs: Vec<TaskInput>,
job_submission: ValidatedJobSubmission,
ready_queue_sender: ReadyQueueSenderType,
db_connector: DbConnectorType,
task_instance_pool_connector: TaskInstancePoolConnectorType,
) -> Result<Self, CacheError> {
let num_tasks = submitted_task_graph.get_num_tasks();
if 0 == num_tasks {
return Err(InternalError::TaskGraphCorrupted(
"task graph with no task is unsupported".to_owned(),
)
.into());
}

let task_graph = TaskGraph::create(submitted_task_graph, inputs).await?;
let num_tasks = job_submission.task_graph().get_num_tasks();
let task_graph = TaskGraph::create(job_submission).await?;
let job_execution_state = JobExecutionState {
state: JobState::Ready,
task_graph,
Expand Down
159 changes: 159 additions & 0 deletions components/spider-storage/src/cache/job_submission.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
use spider_core::{task::TaskGraph, types::io::TaskInput};

use super::error::InternalError;

/// A task graph paired with its job inputs, checked for consistency at construction.
///
/// Construction (via [`ValidatedJobSubmission::validate`]) enforces two invariants:
///
/// * The task graph is non-empty (contains at least one task).
/// * The number of supplied job inputs equals the number of graph inputs the task graph expects.
///
/// Because these checks happen exactly once at the boundary, code receiving this type may rely
/// on the invariants without repeating the validation.
#[derive(Debug)]
pub struct ValidatedJobSubmission {
    task_graph: TaskGraph,
    inputs: Vec<TaskInput>,
}

impl ValidatedJobSubmission {
/// Creates a new validated job submission.
///
/// # Returns
///
/// The validated job submission on success.
///
/// # Errors
///
/// Returns an error if:
///
/// * [`InternalError::TaskGraphEmpty`] if the task graph contains no tasks.
/// * [`InternalError::TaskGraphInputSizeMismatch`] if the number of inputs does not match the
/// number of graph inputs.
pub fn validate(task_graph: TaskGraph, inputs: Vec<TaskInput>) -> Result<Self, InternalError> {
Comment thread
sitaowang1998 marked this conversation as resolved.
Outdated
let num_tasks = task_graph.get_num_tasks();
if num_tasks == 0 {
return Err(InternalError::TaskGraphEmpty);
}
let expected_inputs = task_graph.get_task_graph_input_indices().len();
let actual_inputs = inputs.len();
if expected_inputs != actual_inputs {
return Err(InternalError::TaskGraphInputSizeMismatch {
expected: expected_inputs,
actual: actual_inputs,
});
}
Comment thread
sitaowang1998 marked this conversation as resolved.
Outdated
Ok(Self { task_graph, inputs })
}

/// # Returns
///
/// A reference to the validated task graph.
#[must_use]
pub const fn task_graph(&self) -> &TaskGraph {
&self.task_graph
}

/// # Returns
///
/// A reference to the validated job inputs.
#[must_use]
pub fn inputs(&self) -> &[TaskInput] {
&self.inputs
}

/// Consumes the wrapper and returns the owned task graph and job inputs.
///
/// # Returns
///
/// A tuple of `(task_graph, inputs)`.
#[must_use]
pub fn into_parts(self) -> (TaskGraph, Vec<TaskInput>) {
(self.task_graph, self.inputs)
}
}

#[cfg(test)]
mod tests {
    use spider_core::{
        task::{
            DataTypeDescriptor,
            ExecutionPolicy,
            TaskDescriptor,
            TaskGraph as SubmittedTaskGraph,
            TdlContext,
            ValueTypeDescriptor,
        },
        types::io::TaskInput,
    };

    use super::{super::error::InternalError, *};

    /// # Returns
    ///
    /// A task graph containing a single task that expects one bytes-typed input.
    fn create_single_input_task_graph() -> SubmittedTaskGraph {
        let bytes_type = DataTypeDescriptor::Value(ValueTypeDescriptor::bytes());
        let mut graph =
            SubmittedTaskGraph::new(None, None).expect("task graph creation should succeed");
        graph
            .insert_task(TaskDescriptor {
                tdl_context: TdlContext {
                    package: "test_pkg".to_owned(),
                    task_func: "test_fn".to_owned(),
                },
                execution_policy: Some(ExecutionPolicy::default()),
                inputs: vec![bytes_type],
                outputs: vec![],
                input_sources: None,
            })
            .expect("task insertion should succeed");
        graph
    }

    #[test]
    fn valid_job_submission_succeeds() {
        let graph = create_single_input_task_graph();
        let inputs = vec![TaskInput::ValuePayload(vec![1u8; 4])];
        let result = ValidatedJobSubmission::validate(graph, inputs);
        assert!(result.is_ok(), "valid submission should succeed");
    }

    #[test]
    fn empty_task_graph_fails() {
        let graph =
            SubmittedTaskGraph::new(None, None).expect("task graph creation should succeed");
        let inputs = vec![];
        let result = ValidatedJobSubmission::validate(graph, inputs);
        assert!(
            matches!(result, Err(InternalError::TaskGraphEmpty)),
            "empty task graph should return TaskGraphEmpty"
        );
    }

    #[test]
    fn mismatched_input_count_fails() {
        let graph = create_single_input_task_graph();
        let inputs = vec![];
        let result = ValidatedJobSubmission::validate(graph, inputs);
        assert!(
            matches!(
                result,
                Err(InternalError::TaskGraphInputSizeMismatch {
                    expected: 1,
                    actual: 0
                })
            ),
            "mismatched input count should return TaskGraphInputSizeMismatch"
        );
    }

    #[test]
    fn into_parts_returns_owned_components() {
        let graph = create_single_input_task_graph();
        let inputs = vec![TaskInput::ValuePayload(vec![1u8; 4])];
        let submission =
            ValidatedJobSubmission::validate(graph, inputs).expect("submission should be valid");
        let (graph, inputs) = submission.into_parts();
        assert_eq!(graph.get_num_tasks(), 1, "task graph should have 1 task");
        assert_eq!(inputs.len(), 1, "should have 1 input");
    }
}
54 changes: 28 additions & 26 deletions components/spider-storage/src/cache/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,7 @@ use std::{
};

use spider_core::{
task::{
Task,
TaskGraph as SubmittedTaskGraph,
TaskIndex,
TaskState,
TdlContext,
TerminationTaskDescriptor,
TimeoutPolicy,
},
task::{Task, TaskIndex, TaskState, TdlContext, TerminationTaskDescriptor, TimeoutPolicy},
types::{
id::TaskInstanceId,
io::{ExecutionContext, TaskInput, TaskOutput},
Expand All @@ -24,6 +16,7 @@ use tokio::sync::RwLock;
use crate::cache::{
error::{CacheError, InternalError, StaleStateError},
io::{InputReader, OutputReader, OutputWriter, ValuePayload},
job_submission::ValidatedJobSubmission,
sync::{Reader, SharedRw, Writer},
};

Expand All @@ -38,7 +31,7 @@ pub struct TaskGraph {
impl TaskGraph {
/// Factory function.
///
/// Creates a new task graph from a submitted task graph and the input task inputs.
/// Creates a new task graph from a validated job submission.
///
/// # Returns
///
Expand All @@ -51,28 +44,18 @@ impl TaskGraph {
/// * [`InternalError::TaskGraphCorrupted`] if:
/// * Any dataflow deps' index is out-of-range.
/// * Any task index is out-of-range.
/// * [`InternalError::TaskGraphInputsSizeMismatch`] if the number of provided inputs does not
/// match the task graph’s expected number of inputs.
/// * Forwards [`SharedTaskControlBlock::create`]'s return values on failure.
///
/// # Panics
///
/// Panics if the internal TCB buffer is corrupted.
pub async fn create(
submitted_task_graph: &SubmittedTaskGraph,
inputs: Vec<TaskInput>,
) -> Result<Self, InternalError> {
pub async fn create(job_submission: ValidatedJobSubmission) -> Result<Self, InternalError> {
let (submitted_task_graph, inputs) = job_submission.into_parts();
let dataflow_dep_buffer: Vec<SharedRw<ValuePayload>> = (0..submitted_task_graph
.get_num_dataflow_deps())
.map(|_| SharedRw::new(RwLock::new(ValuePayload::default())))
.collect();
let task_graph_input_indices = submitted_task_graph.get_task_graph_input_indices();
if inputs.len() != task_graph_input_indices.len() {
return Err(InternalError::TaskGraphInputsSizeMismatch(
task_graph_input_indices.len(),
inputs.len(),
));
}
for (deps_index, input) in task_graph_input_indices.into_iter().zip(inputs) {
let dataflow_dep = dataflow_dep_buffer.get(deps_index).ok_or_else(|| {
InternalError::TaskGraphCorrupted(
Expand Down Expand Up @@ -949,6 +932,7 @@ mod tests {
};

use super::*;
use crate::cache::job_submission::ValidatedJobSubmission;

/// # Returns
///
Expand Down Expand Up @@ -1092,7 +1076,9 @@ mod tests {
let inputs: Vec<TaskInput> = (0..num_inputs)
.map(|_| TaskInput::ValuePayload(vec![0u8; 4]))
.collect();
TaskGraph::create(&submitted, inputs)
let job_submission = ValidatedJobSubmission::validate(submitted, inputs)
.expect("job submission should be valid");
TaskGraph::create(job_submission)
.await
.expect("cache task graph creation should succeed")
}
Expand All @@ -1107,7 +1093,7 @@ mod tests {
max_num_instances: u32,
max_num_retry: u32,
) -> SharedTerminationTaskControlBlock {
let submitted = SubmittedTaskGraph::new(
let mut submitted = SubmittedTaskGraph::new(
Some(TerminationTaskDescriptor {
tdl_context: TdlContext {
package: "test_pkg".to_owned(),
Expand All @@ -1122,7 +1108,21 @@ mod tests {
None,
)
.expect("task graph with commit task should be created");
let task_graph = TaskGraph::create(&submitted, vec![])
submitted
.insert_task(TaskDescriptor {
tdl_context: TdlContext {
package: "test_pkg".to_owned(),
task_func: "dummy_fn".to_owned(),
},
execution_policy: Some(ExecutionPolicy::default()),
inputs: vec![],
outputs: vec![],
input_sources: None,
})
.expect("task insertion should succeed");
let job_submission = ValidatedJobSubmission::validate(submitted, vec![])
.expect("job submission should be valid");
let task_graph = TaskGraph::create(job_submission)
.await
.expect("cache task graph creation should succeed");
task_graph
Expand Down Expand Up @@ -1237,7 +1237,9 @@ mod tests {
TaskInput::ValuePayload(input_a),
TaskInput::ValuePayload(input_b),
];
TaskGraph::create(&submitted, inputs)
let job_submission = ValidatedJobSubmission::validate(submitted, inputs)
.expect("job submission should be valid");
TaskGraph::create(job_submission)
.await
.expect("cache task graph creation should succeed")
}
Expand Down
Loading
Loading