diff --git a/rivetkit-rust/packages/rivetkit-core/src/actor/sqlite.rs b/rivetkit-rust/packages/rivetkit-core/src/actor/sqlite.rs index fac2daa9cd..f3fd2ff842 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/actor/sqlite.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/actor/sqlite.rs @@ -1,7 +1,9 @@ use std::collections::HashSet; use std::io::Cursor; -#[cfg(feature = "sqlite-local")] -use std::sync::Arc; +use std::sync::{ + Arc, + atomic::{AtomicBool, Ordering}, +}; #[cfg(feature = "sqlite-local")] use std::time::Duration; @@ -79,6 +81,7 @@ pub struct SqliteDb { /// always sets up sqlite storage under the hood, so handle/actor_id are /// not a reliable signal for whether the user opted in; this flag is. enabled: bool, + shutdown_closed: Arc, #[cfg(feature = "sqlite-local")] // Forced-sync: native SQLite handles are used inside spawn_blocking and // synchronous diagnostic accessors. @@ -115,6 +118,7 @@ impl SqliteDb { generation, backend: select_sqlite_backend(enabled, remote_sqlite), enabled, + shutdown_closed: Default::default(), #[cfg(feature = "sqlite-local")] db: Default::default(), #[cfg(feature = "sqlite-local")] @@ -143,6 +147,7 @@ impl SqliteDb { &self, request: protocol::SqliteGetPagesRequest, ) -> Result { + self.ensure_not_shutdown_closed()?; self.handle()?.sqlite_get_pages(request).await } @@ -150,15 +155,18 @@ impl SqliteDb { &self, request: protocol::SqliteCommitRequest, ) -> Result { + self.ensure_not_shutdown_closed()?; self.handle()?.sqlite_commit(request).await } pub async fn open(&self) -> Result<()> { + self.ensure_not_shutdown_closed()?; match self.backend { SqliteBackend::LocalNative => { #[cfg(feature = "sqlite-local")] { let _open_guard = self.open_lock.lock().await; + self.ensure_not_shutdown_closed()?; if self.db.lock().is_some() { return Ok(()); } @@ -178,6 +186,10 @@ impl SqliteDb { vfs_metrics, ) .await?; + if self.shutdown_closed.load(Ordering::SeqCst) { + native_db.close().await?; + return Err(SqliteRuntimeError::Closed.build()); + } *self.db.lock() = Some(native_db); self.ensure_preload_hint_flush_task()?; Ok(()) @@ -272,6 +284,7 @@ impl SqliteDb { } pub async fn exec(&self, sql: impl Into) -> Result { + self.ensure_not_shutdown_closed()?; let sql = sql.into(); match self.backend { SqliteBackend::LocalNative => self.local_exec(sql).await, @@ -285,6 +298,7 @@ impl SqliteDb { sql: impl Into, params: Option>, ) -> Result { + self.ensure_not_shutdown_closed()?; let sql = sql.into(); match self.backend { SqliteBackend::LocalNative => self.local_query(sql, params).await, @@ -300,6 +314,7 @@ impl SqliteDb { sql: impl Into, params: Option>, ) -> Result { + self.ensure_not_shutdown_closed()?; let sql = sql.into(); match self.backend { SqliteBackend::LocalNative => self.local_run(sql, params).await, @@ -315,6 +330,7 @@ impl SqliteDb { sql: impl Into, params: Option>, ) -> Result { + self.ensure_not_shutdown_closed()?; let sql = sql.into(); match self.backend { SqliteBackend::LocalNative => self.local_execute(sql, params).await, @@ -328,6 +344,7 @@ impl SqliteDb { sql: impl Into, params: Option>, ) -> Result { + self.ensure_not_shutdown_closed()?; let sql = sql.into(); match self.backend { SqliteBackend::LocalNative => self.local_execute_write(sql, params).await, @@ -354,6 +371,9 @@ impl SqliteDb { } pub(crate) async fn cleanup(&self) -> Result<()> { + self.shutdown_closed.store(true, Ordering::SeqCst); + #[cfg(feature = "sqlite-local")] + let _open_guard = self.open_lock.lock().await; #[cfg(feature = "sqlite-local")] { self.stop_preload_hint_flush_task(); @@ -362,6 +382,13 @@ impl SqliteDb { self.close().await } + fn ensure_not_shutdown_closed(&self) -> Result<()> { + if self.shutdown_closed.load(Ordering::SeqCst) { + return Err(SqliteRuntimeError::Closed.build()); + } + Ok(()) + } + #[cfg(feature = "sqlite-local")] fn ensure_preload_hint_flush_task(&self) -> Result<()> { if !sqlite_optimization_flags().preload_hint_flush {