diff --git a/Cargo.lock b/Cargo.lock index 516efac4..63728323 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -703,9 +703,9 @@ checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" [[package]] name = "js-sys" -version = "0.3.97" +version = "0.3.98" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1840c94c045fbcf8ba2812c95db44499f7c64910a912551aaaa541decebcacf" +checksum = "67df7112613f8bfd9150013a0314e196f4800d3201ae742489d999db2f979f08" dependencies = [ "cfg-if", "futures-util", @@ -1423,6 +1423,7 @@ dependencies = [ "thiserror", "tokio", "tokio-util", + "tracing", "uuid", ] @@ -1839,9 +1840,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.52.2" +version = "1.52.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "110a78583f19d5cdb2c5ccf321d1290344e71313c6c37d43520d386027d18386" +checksum = "8fc7f01b389ac15039e4dc9531aa973a135d7a4135281b12d7c1bc79fd57fffe" dependencies = [ "bytes", "libc", @@ -2039,9 +2040,9 @@ checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b" [[package]] name = "wasm-bindgen" -version = "0.2.120" +version = "0.2.121" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df52b6d9b87e0c74c9edfa1eb2d9bf85e5d63515474513aa50fa181b3c4f5db1" +checksum = "49ace1d07c165b0864824eee619580c4689389afa9dc9ed3a4c75040d82e6790" dependencies = [ "cfg-if", "once_cell", @@ -2052,9 +2053,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.120" +version = "0.2.121" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78b1041f495fb322e64aca85f5756b2172e35cd459376e67f2a6c9dffcedb103" +checksum = "8e68e6f4afd367a562002c05637acb8578ff2dea1943df76afb9e83d177c8578" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -2062,9 +2063,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = 
"0.2.120" +version = "0.2.121" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9dcd0ff20416988a18ac686d4d4d0f6aae9ebf08a389ff5d29012b05af2a1b41" +checksum = "d95a9ec35c64b2a7cb35d3fead40c4238d0940c86d107136999567a4703259f2" dependencies = [ "bumpalo", "proc-macro2", @@ -2075,9 +2076,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.120" +version = "0.2.121" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49757b3c82ebf16c57d69365a142940b384176c24df52a087fb748e2085359ea" +checksum = "c4e0100b01e9f0d03189a92b96772a1fb998639d981193d7dbab487302513441" dependencies = [ "unicode-ident", ] diff --git a/components/spider-core/src/types/io.rs b/components/spider-core/src/types/io.rs index 00c06057..91b4b425 100644 --- a/components/spider-core/src/types/io.rs +++ b/components/spider-core/src/types/io.rs @@ -22,3 +22,39 @@ pub struct ExecutionContext { pub timeout_policy: TimeoutPolicy, pub inputs: Vec, } + +/// Serialized job inputs, each element a msgpack-serialized [`TaskInput`]. +pub type SerializedJobInputs = Vec>; + +/// Deserializes msgpack-serialized job inputs. +/// +/// # Returns +/// +/// The deserialized job inputs on success. +/// +/// # Errors +/// +/// Returns `rmp_serde::decode::Error` if any input fails to deserialize. +pub fn deserialize_job_inputs( + inputs: &SerializedJobInputs, +) -> Result, rmp_serde::decode::Error> { + inputs + .iter() + .map(|bytes| rmp_serde::from_slice(bytes)) + .collect() +} + +/// Serializes job inputs to msgpack. +/// +/// # Returns +/// +/// The serialized job inputs on success. +/// +/// # Errors +/// +/// Returns `rmp_serde::encode::Error` if any input fails to serialize. 
+pub fn serialize_job_inputs( + inputs: &[TaskInput], +) -> Result<SerializedJobInputs, rmp_serde::encode::Error> { + inputs.iter().map(rmp_serde::to_vec).collect() +} diff --git a/components/spider-storage/Cargo.toml b/components/spider-storage/Cargo.toml index 2a661e89..3fdb5a18 100644 --- a/components/spider-storage/Cargo.toml +++ b/components/spider-storage/Cargo.toml @@ -31,6 +31,7 @@ tokio = { version = "1.50.0", features = [ "sync", "time" ] } +tracing = { version = "0.1.44", features = ["attributes"] } uuid = { version = "1.19.0", features = ["serde"] } [dev-dependencies] diff --git a/components/spider-storage/src/cache/job.rs b/components/spider-storage/src/cache/job.rs index 5c575e8e..9d1a1f1c 100644 --- a/components/spider-storage/src/cache/job.rs +++ b/components/spider-storage/src/cache/job.rs @@ -100,6 +100,42 @@ impl< self.inner.id } + /// # Returns + /// + /// The current job state. + pub async fn state(&self) -> JobState { + self.inner.job_execution_state.read_state().await + } + + /// Gets the outputs of the job from the in-memory task graph. + /// + /// # Returns + /// + /// The outputs of the job on success. + /// + /// # Errors + /// + /// Returns an error if: + /// + /// * [`InternalError::UnexpectedJobState`] if the job is not in [`JobState::Succeeded`]. + /// * [`InternalError::TaskInputNotReady`] if any output has no value. + pub async fn get_outputs(&self) -> Result<Vec<TaskOutput>, CacheError> { + let jcb = &self.inner; + let job = jcb.job_execution_state.read_succeeded().await?; + let mut outputs = Vec::new(); + for output_reader in job.task_graph.get_outputs() { + let payload = output_reader + .read() + .await + .as_ref() + .ok_or(InternalError::TaskInputNotReady)? + .clone(); + outputs.push(payload); + } + drop(job); + Ok(outputs) + } + /// Starts the job. /// /// Any tasks in [`TaskState::Ready`] will be enqueued to the ready queue on success. @@ -722,6 +758,13 @@ struct JobExecutionStateHandle< impl JobExecutionStateHandle { + /// # Returns + /// + /// The current job state. 
+ async fn read_state(&self) -> JobState { + self.inner.read().await.state + } + /// # Returns /// /// A reader guard of the underlying job execution state on success. @@ -802,6 +845,22 @@ impl Result>, CacheError> { + self.validate_and_read(JobExecutionState::ensure_succeeded) + .await + } + /// # Returns /// /// A writer guard of the underlying job execution state on success. @@ -979,7 +1038,7 @@ impl< /// Returns an error if: /// /// * [`InternalError::UnexpectedJobState`] if the job is in an unexpected state. - /// * [`StaleStateError::JobNoLongerCommitReady`] if the job is no longer cleanup-ready. + /// * [`StaleStateError::JobNoLongerCleanupReady`] if the job is no longer cleanup-ready. fn ensure_cleanup_ready(&self) -> Result<(), CacheError> { if !matches!(self.state, JobState::CleanupReady) { if self.state.is_terminal() { @@ -994,6 +1053,24 @@ impl< Ok(()) } + /// Ensures that the job is currently in the [`JobState::Succeeded`] state. + /// + /// # Errors + /// + /// Returns an error if: + /// + /// * [`InternalError::UnexpectedJobState`] if the job is in an unexpected state. + fn ensure_succeeded(&self) -> Result<(), CacheError> { + if !matches!(self.state, JobState::Succeeded) { + return Err(UnexpectedJobState { + current: self.state, + expected: JobState::Succeeded, + } + .into()); + } + Ok(()) + } + /// Ensures that the job is currently in a cancellable state. /// /// # Errors @@ -1002,7 +1079,7 @@ impl< /// /// * [`StaleStateError::JobCancellationAlreadyRequested`] if job cancellation has already been /// requested. - /// * [`StaleStateError::JobAlreadyCancelled`] if the job is already been cancelled. + /// * [`StaleStateError::JobAlreadyCancelled`] if the job has already been cancelled. /// * [`StaleStateError::JobAlreadyTerminated`] if the job has already terminated. 
fn ensure_cancellable(&self) -> Result<(), CacheError> { if matches!(self.state, JobState::CleanupReady) { diff --git a/components/spider-storage/src/state.rs b/components/spider-storage/src/state.rs index 90c28592..e9d905ac 100644 --- a/components/spider-storage/src/state.rs +++ b/components/spider-storage/src/state.rs @@ -1,5 +1,10 @@ pub mod error; pub mod job_cache; +pub mod service; pub use error::StorageServerError; pub use job_cache::JobCache; +pub use service::ServiceState; + +#[cfg(test)] +mod test_mocks; diff --git a/components/spider-storage/src/state/error.rs b/components/spider-storage/src/state/error.rs index 2df78756..2c5e8354 100644 --- a/components/spider-storage/src/state/error.rs +++ b/components/spider-storage/src/state/error.rs @@ -1,6 +1,6 @@ -use spider_core::types::id::JobId; +use spider_core::{task, types::id::JobId}; -use crate::cache::error::CacheError; +use crate::{cache::error::CacheError, db::DbError}; /// Errors that can occur during storage server operations. 
#[derive(thiserror::Error, Debug)] @@ -8,15 +8,24 @@ pub enum StorageServerError { #[error(transparent)] Cache(#[from] CacheError), + #[error(transparent)] + Db(#[from] DbError), + + #[error(transparent)] + Task(#[from] task::Error), + #[error("stale session")] StaleSession, #[error("server is shutting down: {0}")] Stopping(String), - #[error("bad request: {0}")] - BadRequest(String), + #[error("job not found in cache: {0:?}")] + JobNotFound(JobId), #[error("job already exists: {0:?}")] JobAlreadyExists(JobId), + + #[error("bad request: {0}")] + BadRequest(String), } diff --git a/components/spider-storage/src/state/job_cache.rs b/components/spider-storage/src/state/job_cache.rs index 6ad3c7ce..ff5883a3 100644 --- a/components/spider-storage/src/state/job_cache.rs +++ b/components/spider-storage/src/state/job_cache.rs @@ -5,7 +5,7 @@ use crate::{ cache::job::SharedJobControlBlock, db::InternalJobOrchestration, ready_queue::ReadyQueueSender, - state::error::StorageServerError, + state::StorageServerError, task_instance_pool::TaskInstancePoolConnector, }; @@ -131,7 +131,6 @@ mod tests { use std::sync::Arc; use spider_core::{ - job::JobState, task::{ DataTypeDescriptor, ExecutionPolicy, @@ -140,10 +139,7 @@ mod tests { TdlContext, ValueTypeDescriptor, }, - types::{ - id::JobId, - io::{TaskInput, TaskOutput}, - }, + types::{id::JobId, io::TaskInput}, }; use super::*; @@ -152,111 +148,11 @@ mod tests { error::InternalError, job::SharedJobControlBlock, job_submission::ValidatedJobSubmission, - task::{SharedTaskControlBlock, SharedTerminationTaskControlBlock}, }, - db::DbError, ready_queue::ReadyQueueSender, - task_instance_pool::{TaskInstanceMetadata, TaskInstancePoolConnector}, + state::test_mocks::{MockDbConnector, MockReadyQueueSender, MockTaskInstancePoolConnector}, }; - /// A mock ready queue sender for testing. 
- #[derive(Clone, Default)] - struct MockReadyQueueSender; - - #[async_trait::async_trait] - impl ReadyQueueSender for MockReadyQueueSender { - async fn send_task_ready( - &self, - _rg_id: spider_core::types::id::ResourceGroupId, - _job_id: JobId, - _task_indices: Vec, - ) -> Result<(), InternalError> { - Ok(()) - } - - async fn send_commit_ready( - &self, - _rg_id: spider_core::types::id::ResourceGroupId, - _job_id: JobId, - ) -> Result<(), InternalError> { - Ok(()) - } - - async fn send_cleanup_ready( - &self, - _rg_id: spider_core::types::id::ResourceGroupId, - _job_id: JobId, - ) -> Result<(), InternalError> { - Ok(()) - } - } - - /// A mock DB connector for testing. - #[derive(Clone, Default)] - struct MockDbConnector; - - #[async_trait::async_trait] - impl InternalJobOrchestration for MockDbConnector { - async fn start(&self, _job_id: JobId) -> Result<(), DbError> { - Ok(()) - } - - async fn set_state(&self, _job_id: JobId, _state: JobState) -> Result<(), DbError> { - Ok(()) - } - - async fn commit_outputs( - &self, - _job_id: JobId, - _outputs: Vec, - _has_commit_task: bool, - ) -> Result<(), DbError> { - Ok(()) - } - - async fn cancel(&self, _job_id: JobId, _has_cleanup_task: bool) -> Result<(), DbError> { - Ok(()) - } - - async fn fail(&self, _job_id: JobId, _error_message: String) -> Result<(), DbError> { - Ok(()) - } - - async fn delete_expired_terminated_jobs( - &self, - _expire_after_sec: u64, - ) -> Result, DbError> { - Ok(Vec::new()) - } - } - - /// A mock task instance pool connector for testing. 
- #[derive(Clone, Default)] - struct MockTaskInstancePoolConnector; - - #[async_trait::async_trait] - impl TaskInstancePoolConnector for MockTaskInstancePoolConnector { - fn get_next_available_task_instance_id(&self) -> spider_core::types::id::TaskInstanceId { - 1 - } - - async fn register_task_instance( - &self, - _tcb: SharedTaskControlBlock, - _registration: TaskInstanceMetadata, - ) -> Result<(), InternalError> { - Ok(()) - } - - async fn register_termination_task_instance( - &self, - _termination_tcb: SharedTerminationTaskControlBlock, - _registration: TaskInstanceMetadata, - ) -> Result<(), InternalError> { - Ok(()) - } - } - async fn create_test_jcb( job_id: JobId, ) -> SharedJobControlBlock @@ -285,7 +181,7 @@ mod tests { spider_core::types::id::ResourceGroupId::new(), job_submission, MockReadyQueueSender, - MockDbConnector, + MockDbConnector::default(), MockTaskInstancePoolConnector, ) .await @@ -465,7 +361,7 @@ mod tests { spider_core::types::id::ResourceGroupId::new(), job_submission, sender, - MockDbConnector, + MockDbConnector::default(), MockTaskInstancePoolConnector, ) .await diff --git a/components/spider-storage/src/state/service.rs b/components/spider-storage/src/state/service.rs new file mode 100644 index 00000000..0fe856dd --- /dev/null +++ b/components/spider-storage/src/state/service.rs @@ -0,0 +1,1040 @@ +use std::sync::Arc; + +use spider_core::{ + job::JobState, + task::{TaskGraph, TaskIndex}, + types::{ + id::{ExecutionManagerId, JobId, ResourceGroupId, SessionId, TaskInstanceId}, + io::{ExecutionContext, SerializedJobInputs, TaskOutput, deserialize_job_inputs}, + }, +}; +use tracing::{debug, instrument}; + +use crate::{ + cache::{ + TaskId, + error::CacheError, + job::SharedJobControlBlock, + job_submission::ValidatedJobSubmission, + }, + db::DbStorage, + ready_queue::{ReadyQueueReceiverHandle, ReadyQueueSender}, + state::{JobCache, StorageServerError}, + task_instance_pool::TaskInstancePoolConnector, +}; + +/// Inner data for 
[`ServiceState`], holding all storage services. +/// +/// The job cache is stored directly (not in an Arc) so that cloning the outer `ServiceState` +/// only clones a single `Arc`. +struct ServiceStateInner< + ReadyQueueSenderType: ReadyQueueSender, + DbConnectorType: DbStorage, + TaskInstancePoolConnectorType: TaskInstancePoolConnector, +> { + db: DbConnectorType, + session_id: SessionId, + job_cache: JobCache, + ready_queue_sender: ReadyQueueSenderType, + _ready_queue_receiver: ReadyQueueReceiverHandle, + task_instance_pool_connector: TaskInstancePoolConnectorType, +} + +/// Per-request service state providing access to the storage layer. +/// +/// Holds a DB connector, session ID, job cache, and ready queue handles. Request handlers call +/// methods on `ServiceState` directly. +/// +/// Internally wraps a single [`Arc`] around [`ServiceStateInner`] so that cloning is cheap (one +/// Arc clone instead of cloning each field). +/// +/// # Type Parameters +/// +/// * `ReadyQueueSenderType` - The type of the ready queue sender. +/// * `DbConnectorType` - The type of the DB-layer connector. +/// * `TaskInstancePoolConnectorType` - The type of the task instance pool connector. +#[derive(Clone)] +pub struct ServiceState< + ReadyQueueSenderType: ReadyQueueSender, + DbConnectorType: DbStorage, + TaskInstancePoolConnectorType: TaskInstancePoolConnector, +> { + inner: Arc< + ServiceStateInner, + >, +} + +impl< + ReadyQueueSenderType: ReadyQueueSender, + DbConnectorType: DbStorage, + TaskInstancePoolConnectorType: TaskInstancePoolConnector, +> ServiceState +{ + /// Factory function. + /// + /// # Returns + /// + /// A newly created [`ServiceState`] from its constituent parts. 
+ pub fn new( + db: DbConnectorType, + session_id: SessionId, + job_cache: JobCache, + ready_queue_sender: ReadyQueueSenderType, + ready_queue_receiver: ReadyQueueReceiverHandle, + task_instance_pool_connector: TaskInstancePoolConnectorType, + ) -> Self { + Self { + inner: Arc::new(ServiceStateInner { + db, + session_id, + job_cache, + ready_queue_sender, + _ready_queue_receiver: ready_queue_receiver, + task_instance_pool_connector, + }), + } + } + + /// Registers a job in the database and inserts its control block into the cache. + /// + /// # Returns + /// + /// The ID of the registered job on success. + /// + /// # Errors + /// + /// Returns an error if: + /// + /// * Forwards [`TaskGraph::from_json`]'s return values on failure. + /// * Forwards [`deserialize_job_inputs`]'s return values on failure. + /// * Forwards [`ValidatedJobSubmission::create`]'s return values on failure. + /// * Forwards [`ExternalJobOrchestration::register`]'s return values on failure. + /// * Forwards [`SharedJobControlBlock::create`]'s return values on failure. + /// * Forwards [`JobCache::insert`]'s return values on failure. 
+ #[instrument( + skip(self, serialized_task_graph, serialized_job_inputs), + fields(job_id) + )] + pub async fn register_job( + &self, + resource_group_id: ResourceGroupId, + serialized_task_graph: String, + serialized_job_inputs: SerializedJobInputs, + ) -> Result { + let task_graph = + TaskGraph::from_json(&serialized_task_graph).map_err(StorageServerError::Task)?; + let inputs = deserialize_job_inputs(&serialized_job_inputs) + .map_err(|e| StorageServerError::Task(e.into()))?; + let job_submission = + ValidatedJobSubmission::create(task_graph, inputs).map_err(CacheError::from)?; + + let job_id = self + .inner + .db + .register(resource_group_id, &job_submission) + .await?; + + tracing::Span::current().record("job_id", tracing::field::debug(&job_id)); + + let jcb = SharedJobControlBlock::create( + job_id, + resource_group_id, + job_submission, + self.inner.ready_queue_sender.clone(), + self.inner.db.clone(), + self.inner.task_instance_pool_connector.clone(), + ) + .await?; + + self.inner.job_cache.insert(jcb)?; + debug!("Inserted JCB into job cache."); + + Ok(job_id) + } + + /// Starts a job for execution. + /// + /// # Errors + /// + /// Returns an error if: + /// + /// * [`StorageServerError::JobNotFound`] if the job is not in the cache. + /// * Forwards [`SharedJobControlBlock::start`]'s return values on failure. + #[instrument(skip(self), fields(job_id = ?job_id))] + pub async fn start_job(&self, job_id: JobId) -> Result<(), StorageServerError> { + let jcb = self + .inner + .job_cache + .get(job_id) + .ok_or(StorageServerError::JobNotFound(job_id))?; + debug!("JCB found in cache, starting job."); + jcb.start().await?; + Ok(()) + } + + /// Cancels a job. + /// + /// # Returns + /// + /// The job state after the cancellation operation on success. + /// + /// # Errors + /// + /// Returns an error if: + /// + /// * [`StorageServerError::JobNotFound`] if the job is not in the cache. + /// * Forwards [`SharedJobControlBlock::cancel`]'s return values on failure. 
+ #[instrument(skip(self), fields(job_id = ?job_id))] + pub async fn cancel_job(&self, job_id: JobId) -> Result { + let jcb = self + .inner + .job_cache + .get(job_id) + .ok_or(StorageServerError::JobNotFound(job_id))?; + debug!("JCB found in cache, cancelling job."); + let state = jcb.cancel().await?; + Ok(state) + } + + /// Gets the state of a job. + /// + /// Checks the job cache first; if the JCB is present, returns its in-memory state. Otherwise + /// falls back to the database. + /// + /// # Returns + /// + /// The state of the job on success. + /// + /// # Errors + /// + /// Returns an error if: + /// + /// * Forwards [`ExternalJobOrchestration::get_state`]'s return values on failure (DB fallback). + #[instrument(skip(self), fields(job_id = ?job_id))] + pub async fn get_job_state(&self, job_id: JobId) -> Result { + if let Some(jcb) = self.inner.job_cache.get(job_id) { + debug!("JCB found in cache, returning in-memory state."); + return Ok(jcb.state().await); + } + debug!("JCB not in cache, falling back to database."); + Ok(self.inner.db.get_state(job_id).await?) + } + + /// Gets the outputs of a job. + /// + /// Checks the job cache first; if the JCB is present, returns its in-memory outputs. Otherwise + /// falls back to the database. + /// + /// # Returns + /// + /// The outputs of the job on success. + /// + /// # Errors + /// + /// Returns an error if: + /// + /// * Forwards [`SharedJobControlBlock::get_outputs`]'s return values on failure (cache path). + /// * Forwards [`ExternalJobOrchestration::get_outputs`]'s return values on failure (DB + /// fallback). 
+ #[instrument(skip(self), fields(job_id = ?job_id))] + pub async fn get_job_outputs( + &self, + job_id: JobId, + ) -> Result, StorageServerError> { + if let Some(jcb) = self.inner.job_cache.get(job_id) { + debug!("JCB found in cache, returning in-memory outputs."); + return Ok(jcb.get_outputs().await?); + } + debug!("JCB not in cache, falling back to database."); + Ok(self.inner.db.get_outputs(job_id).await?) + } + + /// Gets the error message of a job from the database. + /// + /// # Returns + /// + /// The error message of the job on success. + /// + /// # Errors + /// + /// Returns an error if: + /// + /// * Forwards [`ExternalJobOrchestration::get_error`]'s return values on failure. + #[instrument(skip(self), fields(job_id = ?job_id))] + pub async fn get_job_error(&self, job_id: JobId) -> Result { + Ok(self.inner.db.get_error(job_id).await?) + } + + /// Creates a task instance for the given task and registers it in the task instance pool. + /// + /// # Returns + /// + /// The execution context for the created task instance on success. + /// + /// # Errors + /// + /// Returns an error if: + /// + /// * [`StorageServerError::StaleSession`] if the session has changed. + /// * [`StorageServerError::JobNotFound`] if the job is not in the cache. + /// * Forwards [`SharedJobControlBlock::create_task_instance`]'s return values on failure. + #[instrument(skip(self, session_id), fields(job_id = ?job_id, task_id = ?task_id))] + pub async fn create_task_instance( + &self, + session_id: SessionId, + job_id: JobId, + task_id: TaskId, + execution_manager_id: ExecutionManagerId, + ) -> Result { + self.validate_session(session_id)?; + let jcb = self + .inner + .job_cache + .get(job_id) + .ok_or(StorageServerError::JobNotFound(job_id))?; + debug!("JCB found in cache, creating task instance."); + Ok(jcb + .create_task_instance(task_id, execution_manager_id) + .await?) + } + + /// Marks a task instance as succeeded. 
+ /// + /// If all tasks have succeeded, commits the job outputs and transitions the job state. + /// + /// # Returns + /// + /// The current job state after the operation on success. + /// + /// # Errors + /// + /// Returns an error if: + /// + /// * [`StorageServerError::StaleSession`] if the session has changed. + /// * [`StorageServerError::JobNotFound`] if the job is not in the cache. + /// * Forwards [`SharedJobControlBlock::succeed_task_instance`]'s return values on failure. + #[instrument( + skip(self, session_id, task_outputs), + fields(job_id = ?job_id, task_instance_id = ?task_instance_id) + )] + pub async fn succeed_task_instance( + &self, + session_id: SessionId, + job_id: JobId, + task_instance_id: TaskInstanceId, + task_index: TaskIndex, + task_outputs: Vec, + ) -> Result { + self.validate_session(session_id)?; + let jcb = self + .inner + .job_cache + .get(job_id) + .ok_or(StorageServerError::JobNotFound(job_id))?; + debug!("JCB found in cache, succeeding task instance."); + let state = jcb + .succeed_task_instance(task_instance_id, task_index, task_outputs) + .await?; + Ok(state) + } + + /// Marks a task instance as failed. + /// + /// # Returns + /// + /// The current job state after the operation on success. + /// + /// # Errors + /// + /// Returns an error if: + /// + /// * [`StorageServerError::StaleSession`] if the session has changed. + /// * [`StorageServerError::JobNotFound`] if the job is not in the cache. + /// * Forwards [`SharedJobControlBlock::fail_task_instance`]'s return values on failure. 
+ #[instrument( + skip(self, session_id, error), + fields(job_id = ?job_id, task_instance_id = ?task_instance_id, task_id = ?task_id) + )] + pub async fn fail_task_instance( + &self, + session_id: SessionId, + job_id: JobId, + task_instance_id: TaskInstanceId, + task_id: TaskId, + error: String, + ) -> Result { + self.validate_session(session_id)?; + let jcb = self + .inner + .job_cache + .get(job_id) + .ok_or(StorageServerError::JobNotFound(job_id))?; + debug!("JCB found in cache, failing task instance."); + let state = jcb + .fail_task_instance(task_instance_id, task_id, error) + .await?; + Ok(state) + } + + /// Validates that the given `session_id` matches the session ID captured at service creation + /// time. + /// + /// # Errors + /// + /// Returns [`StorageServerError::StaleSession`] if the session IDs don't match. + fn validate_session(&self, session_id: SessionId) -> Result<(), StorageServerError> { + if session_id != self.inner.session_id { + debug!("Session ID mismatch."); + return Err(StorageServerError::StaleSession); + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use spider_core::{ + job::JobState, + task::{ + DataTypeDescriptor, + ExecutionPolicy, + TaskDescriptor, + TaskGraph as SubmittedTaskGraph, + TdlContext, + ValueTypeDescriptor, + }, + types::{ + id::{ExecutionManagerId, JobId, ResourceGroupId}, + io::{TaskInput, TaskOutput}, + }, + }; + + use super::*; + use crate::{ + cache::{job::SharedJobControlBlock, job_submission::ValidatedJobSubmission}, + state::{ + StorageServerError, + test_mocks::{MockDbConnector, MockReadyQueueSender, MockTaskInstancePoolConnector}, + }, + }; + + type TestServiceState = + ServiceState; + + const TEST_SESSION_ID: SessionId = 0; + + fn create_test_service() -> TestServiceState { + create_test_service_with_db(MockDbConnector::default()) + } + + fn create_test_service_with_db(db: MockDbConnector) -> TestServiceState { + create_test_service_with_db_and_session(db, TEST_SESSION_ID) + } + + fn 
create_test_service_with_db_and_session( + db: MockDbConnector, + session_id: SessionId, + ) -> TestServiceState { + use crate::ready_queue::{ReadyQueueConfig, create_ready_queue}; + let (_sender, receiver) = + create_ready_queue(ReadyQueueConfig::default()).expect("ready queue creation"); + TestServiceState::new( + db, + session_id, + JobCache::new(), + MockReadyQueueSender, + receiver, + MockTaskInstancePoolConnector, + ) + } + + fn create_test_task_graph() -> SubmittedTaskGraph { + let bytes_type = DataTypeDescriptor::Value(ValueTypeDescriptor::bytes()); + let mut task_graph = + SubmittedTaskGraph::new(None, None).expect("task graph creation should succeed"); + task_graph + .insert_task(TaskDescriptor { + tdl_context: TdlContext { + package: "test_pkg".to_owned(), + task_func: "test_fn".to_owned(), + }, + execution_policy: Some(ExecutionPolicy::default()), + inputs: vec![bytes_type.clone()], + outputs: vec![bytes_type], + input_sources: None, + }) + .expect("task insertion should succeed"); + task_graph + } + + fn create_serialized_test_job_submission() -> (String, SerializedJobInputs) { + let task_graph = create_test_task_graph() + .to_json() + .expect("task graph serialization should succeed"); + let inputs = vec![ + rmp_serde::to_vec(&TaskInput::ValuePayload(vec![0u8; 4])) + .expect("input serialization should succeed"), + ]; + (task_graph, inputs) + } + + async fn create_test_jcb( + job_id: JobId, + ) -> SharedJobControlBlock + { + let task_graph = create_test_task_graph(); + let job_submission = + ValidatedJobSubmission::create(task_graph, vec![TaskInput::ValuePayload(vec![0u8; 4])]) + .expect("job submission should be valid"); + + SharedJobControlBlock::create( + job_id, + ResourceGroupId::new(), + job_submission, + MockReadyQueueSender, + MockDbConnector::default(), + MockTaskInstancePoolConnector, + ) + .await + .expect("JCB creation should succeed") + } + + #[tokio::test] + async fn register_job_returns_job_id_and_inserts_into_cache() -> 
anyhow::Result<()> { + let service = create_test_service(); + let (serialized_task_graph, serialized_inputs) = create_serialized_test_job_submission(); + let job_id = service + .register_job( + ResourceGroupId::new(), + serialized_task_graph, + serialized_inputs, + ) + .await?; + assert_ne!(job_id, JobId::default(), "job ID should be assigned"); + assert!( + service.inner.job_cache.get(job_id).is_some(), + "JCB should be in cache after register_job" + ); + Ok(()) + } + + #[tokio::test] + async fn register_job_returns_error_on_invalid_task_graph() -> anyhow::Result<()> { + let service = create_test_service(); + let result = service + .register_job(ResourceGroupId::new(), "invalid json".to_owned(), vec![]) + .await; + assert!( + matches!(result, Err(StorageServerError::Task(_))), + "register_job should return Task error on invalid task graph JSON" + ); + Ok(()) + } + + #[tokio::test] + async fn register_job_returns_error_on_invalid_input_bytes() -> anyhow::Result<()> { + let service = create_test_service(); + let task_graph = create_test_task_graph() + .to_json() + .expect("task graph serialization should succeed"); + let result = service + .register_job( + ResourceGroupId::new(), + task_graph, + vec![vec![255u8; 1]], // invalid msgpack + ) + .await; + assert!( + matches!(result, Err(StorageServerError::Task(_))), + "register_job should return Task error on invalid msgpack input" + ); + Ok(()) + } + + #[tokio::test] + async fn register_job_returns_error_on_empty_task_graph() -> anyhow::Result<()> { + let service = create_test_service(); + let task_graph = spider_core::task::TaskGraph::new(None, None) + .expect("empty task graph creation should succeed") + .to_json() + .expect("task graph serialization should succeed"); + let result = service + .register_job(ResourceGroupId::new(), task_graph, vec![]) + .await; + assert!( + matches!(result, Err(StorageServerError::Cache(_))), + "register_job should return Cache error on empty task graph" + ); + Ok(()) + } + + 
#[tokio::test] + async fn start_job_starts_cached_job() -> anyhow::Result<()> { + let service = create_test_service(); + let (serialized_task_graph, serialized_inputs) = create_serialized_test_job_submission(); + let job_id = service + .register_job( + ResourceGroupId::new(), + serialized_task_graph, + serialized_inputs, + ) + .await?; + + service.start_job(job_id).await?; + + let state = service.get_job_state(job_id).await?; + assert_eq!(state, JobState::Running); + Ok(()) + } + + #[tokio::test] + async fn start_job_returns_job_not_found_when_not_in_cache() -> anyhow::Result<()> { + let service = create_test_service(); + let result = service.start_job(JobId::new()).await; + assert!( + matches!(result, Err(StorageServerError::JobNotFound(_))), + "start_job should return JobNotFound when job is not in cache" + ); + Ok(()) + } + + #[tokio::test] + async fn cancel_job_returns_job_not_found_when_not_in_cache() -> anyhow::Result<()> { + let service = create_test_service(); + let result = service.cancel_job(JobId::new()).await; + assert!( + matches!(result, Err(StorageServerError::JobNotFound(_))), + "cancel_job should return JobNotFound when job is not in cache" + ); + Ok(()) + } + + #[tokio::test] + async fn cancel_job_transitions_to_terminal_state() -> anyhow::Result<()> { + let service = create_test_service(); + let job_id = JobId::new(); + let jcb = create_test_jcb(job_id).await; + service.inner.job_cache.insert(jcb)?; + + let state = service.cancel_job(job_id).await?; + assert!( + state.is_terminal(), + "cancel should result in terminal state" + ); + assert!( + service.inner.job_cache.get(job_id).is_some(), + "JCB should remain in cache after terminal cancel" + ); + Ok(()) + } + + #[tokio::test] + async fn get_job_state_serves_from_cache_when_jcb_present() -> anyhow::Result<()> { + let service = create_test_service(); + let (serialized_task_graph, serialized_inputs) = create_serialized_test_job_submission(); + let job_id = service + .register_job( + 
ResourceGroupId::new(), + serialized_task_graph, + serialized_inputs, + ) + .await?; + + let state = service.get_job_state(job_id).await?; + assert_eq!(state, JobState::Ready); + Ok(()) + } + + #[tokio::test] + async fn get_job_state_falls_back_to_db_when_not_in_cache() -> anyhow::Result<()> { + let db = MockDbConnector::default(); + let job_id = JobId::new(); + db.states.insert(job_id, JobState::Failed); + + let service = create_test_service_with_db(db); + let state = service.get_job_state(job_id).await?; + assert_eq!(state, JobState::Failed); + Ok(()) + } + + #[tokio::test] + async fn get_job_state_returns_error_for_unknown_job() -> anyhow::Result<()> { + let service = create_test_service(); + let result = service.get_job_state(JobId::new()).await; + assert!(result.is_err(), "get_job_state should fail for unknown job"); + Ok(()) + } + + #[tokio::test] + async fn get_job_outputs_returns_outputs_from_db() -> anyhow::Result<()> { + let db = MockDbConnector::default(); + let job_id = JobId::new(); + let outputs: Vec = vec![vec![1, 2, 3]]; + db.outputs.insert(job_id, outputs.clone()); + + let service = create_test_service_with_db(db); + let result = service.get_job_outputs(job_id).await?; + assert_eq!(result, outputs); + Ok(()) + } + + #[tokio::test] + async fn get_job_outputs_returns_outputs_from_cache_when_jcb_present() -> anyhow::Result<()> { + let service = create_test_service(); + let (serialized_task_graph, serialized_inputs) = create_serialized_test_job_submission(); + let job_id = service + .register_job( + ResourceGroupId::new(), + serialized_task_graph, + serialized_inputs, + ) + .await?; + service.start_job(job_id).await?; + + let context = service + .create_task_instance( + TEST_SESSION_ID, + job_id, + TaskId::Index(0), + ExecutionManagerId::new(), + ) + .await?; + let outputs = vec![vec![0u8; 4]]; + service + .succeed_task_instance( + TEST_SESSION_ID, + job_id, + context.task_instance_id, + 0, + outputs.clone(), + ) + .await?; + + let result = 
service.get_job_outputs(job_id).await?; + assert_eq!(result, outputs); + Ok(()) + } + + #[tokio::test] + async fn get_job_outputs_returns_error_for_unknown_job() -> anyhow::Result<()> { + let service = create_test_service(); + let result = service.get_job_outputs(JobId::new()).await; + assert!( + result.is_err(), + "get_job_outputs should fail for unknown job" + ); + Ok(()) + } + + #[tokio::test] + async fn get_job_outputs_returns_error_when_job_not_succeeded() -> anyhow::Result<()> { + let service = create_test_service(); + let (serialized_task_graph, serialized_inputs) = create_serialized_test_job_submission(); + let job_id = service + .register_job( + ResourceGroupId::new(), + serialized_task_graph, + serialized_inputs, + ) + .await?; + // JCB is in cache but job is still Ready (not Succeeded). + let result = service.get_job_outputs(job_id).await; + assert!( + result.is_err(), + "get_job_outputs should fail when job is not Succeeded" + ); + Ok(()) + } + + #[tokio::test] + async fn get_job_error_returns_error_message_from_db() -> anyhow::Result<()> { + let db = MockDbConnector::default(); + let job_id = JobId::new(); + let error_msg = "something went wrong".to_owned(); + db.errors.insert(job_id, error_msg.clone()); + + let service = create_test_service_with_db(db); + let result = service.get_job_error(job_id).await?; + assert_eq!(result, error_msg); + Ok(()) + } + + #[tokio::test] + async fn get_job_error_returns_error_for_unknown_job() -> anyhow::Result<()> { + let service = create_test_service(); + let result = service.get_job_error(JobId::new()).await; + assert!(result.is_err(), "get_job_error should fail for unknown job"); + Ok(()) + } + + #[tokio::test] + async fn create_task_instance_returns_execution_context() -> anyhow::Result<()> { + let service = create_test_service(); + let (serialized_task_graph, serialized_inputs) = create_serialized_test_job_submission(); + let job_id = service + .register_job( + ResourceGroupId::new(), + serialized_task_graph, + 
serialized_inputs, + ) + .await?; + service.start_job(job_id).await?; + + let context = service + .create_task_instance( + TEST_SESSION_ID, + job_id, + TaskId::Index(0), + ExecutionManagerId::new(), + ) + .await?; + assert_eq!( + context.task_instance_id, 1, + "task instance ID should match mock pool counter" + ); + Ok(()) + } + + #[tokio::test] + async fn create_task_instance_returns_job_not_found_when_not_in_cache() -> anyhow::Result<()> { + let service = create_test_service(); + let result = service + .create_task_instance( + TEST_SESSION_ID, + JobId::new(), + TaskId::Index(0), + ExecutionManagerId::new(), + ) + .await; + assert!( + matches!(result, Err(StorageServerError::JobNotFound(_))), + "create_task_instance should return JobNotFound when job is not in cache" + ); + Ok(()) + } + + #[tokio::test] + async fn succeed_task_instance_transitions_job_to_succeeded() -> anyhow::Result<()> { + let service = create_test_service(); + let (serialized_task_graph, serialized_inputs) = create_serialized_test_job_submission(); + let job_id = service + .register_job( + ResourceGroupId::new(), + serialized_task_graph, + serialized_inputs, + ) + .await?; + service.start_job(job_id).await?; + + let context = service + .create_task_instance( + TEST_SESSION_ID, + job_id, + TaskId::Index(0), + ExecutionManagerId::new(), + ) + .await?; + let state = service + .succeed_task_instance( + TEST_SESSION_ID, + job_id, + context.task_instance_id, + 0, + vec![vec![0u8; 4]], + ) + .await?; + assert_eq!(state, JobState::Succeeded); + assert!( + service.inner.job_cache.get(job_id).is_some(), + "JCB should remain in cache after terminal succeed" + ); + Ok(()) + } + + #[tokio::test] + async fn succeed_task_instance_returns_job_not_found_when_not_in_cache() -> anyhow::Result<()> { + let service = create_test_service(); + let result = service + .succeed_task_instance(TEST_SESSION_ID, JobId::new(), 1, 0, vec![]) + .await; + assert!( + matches!(result, Err(StorageServerError::JobNotFound(_))), + 
"succeed_task_instance should return JobNotFound when job is not in cache" + ); + Ok(()) + } + + #[tokio::test] + async fn fail_task_instance_transitions_job_to_failed() -> anyhow::Result<()> { + let service = create_test_service(); + let (serialized_task_graph, serialized_inputs) = create_serialized_test_job_submission(); + let job_id = service + .register_job( + ResourceGroupId::new(), + serialized_task_graph, + serialized_inputs, + ) + .await?; + service.start_job(job_id).await?; + + let context = service + .create_task_instance( + TEST_SESSION_ID, + job_id, + TaskId::Index(0), + ExecutionManagerId::new(), + ) + .await?; + let state = service + .fail_task_instance( + TEST_SESSION_ID, + job_id, + context.task_instance_id, + TaskId::Index(0), + "test failure".to_owned(), + ) + .await?; + assert_eq!(state, JobState::Failed); + assert!( + service.inner.job_cache.get(job_id).is_some(), + "JCB should remain in cache after terminal fail" + ); + Ok(()) + } + + #[tokio::test] + async fn fail_task_instance_returns_job_not_found_when_not_in_cache() -> anyhow::Result<()> { + let service = create_test_service(); + let result = service + .fail_task_instance( + TEST_SESSION_ID, + JobId::new(), + 1, + TaskId::Index(0), + "error".to_owned(), + ) + .await; + assert!( + matches!(result, Err(StorageServerError::JobNotFound(_))), + "fail_task_instance should return JobNotFound when job is not in cache" + ); + Ok(()) + } + + #[tokio::test] + async fn task_instance_apis_return_stale_session_on_mismatch() -> anyhow::Result<()> { + // Create a service with a higher session ID to simulate a server restart. + let current_session_id: SessionId = 10; + let db = MockDbConnector::default(); + let service = create_test_service_with_db_and_session(db, current_session_id); + + // Register a job so the JCB is in cache. 
+ let (serialized_task_graph, serialized_inputs) = create_serialized_test_job_submission(); + let job_id = service + .register_job( + ResourceGroupId::new(), + serialized_task_graph, + serialized_inputs, + ) + .await?; + + let stale_session_id = current_session_id - 1; + let result = service + .create_task_instance( + stale_session_id, + job_id, + TaskId::Index(0), + ExecutionManagerId::new(), + ) + .await; + assert!( + matches!(result, Err(StorageServerError::StaleSession)), + "create_task_instance should return StaleSession on session mismatch" + ); + Ok(()) + } + + #[tokio::test] + async fn succeed_task_instance_returns_stale_session_on_mismatch() -> anyhow::Result<()> { + let current_session_id: SessionId = 10; + let db = MockDbConnector::default(); + let service = create_test_service_with_db_and_session(db, current_session_id); + + let (serialized_task_graph, serialized_inputs) = create_serialized_test_job_submission(); + let job_id = service + .register_job( + ResourceGroupId::new(), + serialized_task_graph, + serialized_inputs, + ) + .await?; + service.start_job(job_id).await?; + + let context = service + .create_task_instance( + current_session_id, + job_id, + TaskId::Index(0), + ExecutionManagerId::new(), + ) + .await?; + + let stale_session_id = current_session_id - 1; + let result = service + .succeed_task_instance( + stale_session_id, + job_id, + context.task_instance_id, + 0, + vec![vec![0u8; 4]], + ) + .await; + assert!( + matches!(result, Err(StorageServerError::StaleSession)), + "succeed_task_instance should return StaleSession on session mismatch" + ); + Ok(()) + } + + #[tokio::test] + async fn fail_task_instance_returns_stale_session_on_mismatch() -> anyhow::Result<()> { + let current_session_id: SessionId = 10; + let db = MockDbConnector::default(); + let service = create_test_service_with_db_and_session(db, current_session_id); + + let (serialized_task_graph, serialized_inputs) = create_serialized_test_job_submission(); + let job_id = service + 
.register_job( + ResourceGroupId::new(), + serialized_task_graph, + serialized_inputs, + ) + .await?; + service.start_job(job_id).await?; + + let context = service + .create_task_instance( + current_session_id, + job_id, + TaskId::Index(0), + ExecutionManagerId::new(), + ) + .await?; + + let stale_session_id = current_session_id - 1; + let result = service + .fail_task_instance( + stale_session_id, + job_id, + context.task_instance_id, + TaskId::Index(0), + "error".to_owned(), + ) + .await; + assert!( + matches!(result, Err(StorageServerError::StaleSession)), + "fail_task_instance should return StaleSession on session mismatch" + ); + Ok(()) + } +} diff --git a/components/spider-storage/src/state/test_mocks.rs b/components/spider-storage/src/state/test_mocks.rs new file mode 100644 index 00000000..9c29bfdc --- /dev/null +++ b/components/spider-storage/src/state/test_mocks.rs @@ -0,0 +1,244 @@ +use std::{net::IpAddr, sync::Arc}; + +use dashmap::DashMap; +use spider_core::{ + job::JobState, + types::{ + id::{ExecutionManagerId, JobId, ResourceGroupId, SessionId, TaskInstanceId}, + io::TaskOutput, + }, +}; + +use crate::{ + cache::{ + error::InternalError, + job_submission::ValidatedJobSubmission, + task::{SharedTaskControlBlock, SharedTerminationTaskControlBlock}, + }, + db::{ + DbError, + DbStorage, + ExecutionManagerLivenessManagement, + ExternalJobOrchestration, + InternalJobOrchestration, + ResourceGroupManagement, + SessionManagement, + }, + ready_queue::ReadyQueueSender, + task_instance_pool::{TaskInstanceMetadata, TaskInstancePoolConnector}, +}; + +/// A mock ready queue sender for testing. 
+#[derive(Clone, Default)] +pub struct MockReadyQueueSender; + +#[async_trait::async_trait] +impl ReadyQueueSender for MockReadyQueueSender { + async fn send_task_ready( + &self, + _rg_id: ResourceGroupId, + _job_id: JobId, + _task_indices: Vec, + ) -> Result<(), InternalError> { + Ok(()) + } + + async fn send_commit_ready( + &self, + _rg_id: ResourceGroupId, + _job_id: JobId, + ) -> Result<(), InternalError> { + Ok(()) + } + + async fn send_cleanup_ready( + &self, + _rg_id: ResourceGroupId, + _job_id: JobId, + ) -> Result<(), InternalError> { + Ok(()) + } +} + +/// A mock DB connector for testing that implements [`DbStorage`]. +#[derive(Clone)] +pub struct MockDbConnector { + pub states: Arc>, + pub errors: Arc>, + pub outputs: Arc>>, + pub session_id: SessionId, +} + +impl Default for MockDbConnector { + fn default() -> Self { + Self { + states: Arc::new(DashMap::new()), + errors: Arc::new(DashMap::new()), + outputs: Arc::new(DashMap::new()), + session_id: 0, + } + } +} + +#[async_trait::async_trait] +impl ExternalJobOrchestration for MockDbConnector { + async fn register( + &self, + _resource_group_id: ResourceGroupId, + _job_submission: &ValidatedJobSubmission, + ) -> Result { + let job_id = JobId::new(); + self.states.insert(job_id, JobState::Ready); + Ok(job_id) + } + + async fn get_state(&self, job_id: JobId) -> Result { + self.states + .get(&job_id) + .map(|v| *v) + .ok_or(DbError::JobNotFound(job_id)) + } + + async fn get_outputs(&self, job_id: JobId) -> Result, DbError> { + self.outputs + .get(&job_id) + .map(|v| v.clone()) + .ok_or(DbError::JobNotFound(job_id)) + } + + async fn get_error(&self, job_id: JobId) -> Result { + self.errors + .get(&job_id) + .map(|v| v.clone()) + .ok_or(DbError::JobNotFound(job_id)) + } +} + +#[async_trait::async_trait] +impl InternalJobOrchestration for MockDbConnector { + async fn start(&self, job_id: JobId) -> Result<(), DbError> { + self.states.insert(job_id, JobState::Running); + Ok(()) + } + + async fn set_state(&self, 
job_id: JobId, state: JobState) -> Result<(), DbError> { + self.states.insert(job_id, state); + Ok(()) + } + + async fn commit_outputs( + &self, + job_id: JobId, + _job_outputs: Vec, + _has_commit_task: bool, + ) -> Result<(), DbError> { + self.states.insert(job_id, JobState::Succeeded); + Ok(()) + } + + async fn cancel(&self, job_id: JobId, _has_cleanup_task: bool) -> Result<(), DbError> { + self.states.insert(job_id, JobState::Cancelled); + Ok(()) + } + + async fn fail(&self, job_id: JobId, _error_message: String) -> Result<(), DbError> { + self.states.insert(job_id, JobState::Failed); + Ok(()) + } + + async fn delete_expired_terminated_jobs( + &self, + _expire_after_sec: u64, + ) -> Result, DbError> { + Ok(Vec::new()) + } +} + +#[async_trait::async_trait] +impl ResourceGroupManagement for MockDbConnector { + async fn add( + &self, + _external_resource_group_id: String, + _password: Vec, + ) -> Result { + Ok(ResourceGroupId::new()) + } + + async fn verify( + &self, + _resource_group_id: ResourceGroupId, + _password: &[u8], + ) -> Result<(), DbError> { + Ok(()) + } + + async fn delete(&self, _resource_group_id: ResourceGroupId) -> Result<(), DbError> { + Ok(()) + } +} + +#[async_trait::async_trait] +impl ExecutionManagerLivenessManagement for MockDbConnector { + async fn register_execution_manager( + &self, + _ip_address: IpAddr, + ) -> Result { + Ok(ExecutionManagerId::new()) + } + + async fn update_execution_manager_heartbeat( + &self, + _execution_manager_id: ExecutionManagerId, + ) -> Result<(), DbError> { + Ok(()) + } + + async fn is_execution_manager_alive( + &self, + _execution_manager_id: ExecutionManagerId, + ) -> Result { + Ok(true) + } + + async fn get_dead_execution_managers( + &self, + _stale_after_sec: u64, + ) -> Result, DbError> { + Ok(Vec::new()) + } +} + +impl SessionManagement for MockDbConnector { + fn session_id(&self) -> SessionId { + self.session_id + } +} + +impl DbStorage for MockDbConnector {} + +/// A mock task instance pool connector 
for testing. +#[derive(Clone, Default)] +pub struct MockTaskInstancePoolConnector; + +#[async_trait::async_trait] +impl TaskInstancePoolConnector for MockTaskInstancePoolConnector { + fn get_next_available_task_instance_id(&self) -> TaskInstanceId { + 1 + } + + async fn register_task_instance( + &self, + _tcb: SharedTaskControlBlock, + _registration: TaskInstanceMetadata, + ) -> Result<(), InternalError> { + Ok(()) + } + + async fn register_termination_task_instance( + &self, + _termination_tcb: SharedTerminationTaskControlBlock, + _registration: TaskInstanceMetadata, + ) -> Result<(), InternalError> { + Ok(()) + } +}