Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions rivetkit-rust/packages/rivetkit-core/src/actor/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1205,11 +1205,16 @@ impl ActorTask {
let is_new = !persisted.actor.has_initialized;
self.ctx.load_persisted_actor(persisted.actor);
self.ctx.load_last_pushed_alarm(persisted.last_pushed_alarm);
self.ctx.set_has_initialized(true);
self.ctx
.persist_state(SaveStateOpts { immediate: true })
.await
.context("persist actor initialization")?;
// New manual-startup runtimes must not persist initialization until the
// runtime startup_ready handshake completes. The runtime preamble owns
// initial state creation.
if !is_new || !self.factory.requires_manual_startup_ready() {
self.ctx.set_has_initialized(true);
self.ctx
.persist_state(SaveStateOpts { immediate: true })
.await
.context("persist actor initialization")?;
}
let init_inspector_token_started_at = Instant::now();
crate::inspector::auth::init_inspector_token_with_preload(
&self.ctx,
Expand All @@ -1234,6 +1239,12 @@ impl ActorTask {
self.transition_to(LifecycleState::Started);
self.spawn_run_handle(is_new).await?;
if is_new {
// Manual-startup runtimes usually mark initialization during their
// preamble. This is the fallback for runtimes that completed startup
// without doing so.
if !self.ctx.persisted_actor().has_initialized {
self.ctx.set_has_initialized(true);
}
self.ctx
.persist_state(SaveStateOpts { immediate: true })
.await
Expand Down
67 changes: 67 additions & 0 deletions rivetkit-rust/packages/rivetkit-core/tests/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2113,6 +2113,73 @@ mod moved_tests {
assert!(ctx.persisted_actor().has_initialized);
}

#[tokio::test]
async fn manual_startup_does_not_mark_initialized_before_runtime_preamble() {
let kv = new_in_memory();
let ctx = new_with_kv(
"actor-manual-startup-init",
"task-manual-startup-init",
Vec::new(),
"local",
kv,
);
let (observed_tx, observed_rx) = oneshot::channel();
let observed_tx = Arc::new(Mutex::new(Some(observed_tx)));
let factory = Arc::new(ActorFactory::new_with_manual_startup_ready(
Default::default(),
move |mut start| {
let observed_tx = observed_tx.clone();
Box::pin(async move {
observed_tx
.lock()
.expect("observed lock poisoned")
.take()
.expect("observed sender should exist")
.send(start.ctx.persisted_actor().has_initialized)
.expect("observed sender should send");
start.ctx.set_state_initial(vec![4, 5, 6]);
start.ctx.set_has_initialized(true);
start
.startup_ready
.take()
.expect("manual runtime should receive startup ready sender")
.send(Ok(()))
.expect("startup ready receiver should exist");

while let Some(event) = start.events.recv().await {
match event {
ActorEvent::SerializeState { reply, .. } => {
reply.send(Ok(vec![StateDelta::ActorState(start.ctx.state())]));
}
ActorEvent::RunGracefulCleanup { reply, .. } => {
reply.send(Ok(()));
}
_ => {}
}
}
Ok(())
})
},
));
let mut task = new_task_with_factory(ctx.clone(), factory);
let (start_tx, start_rx) = oneshot::channel();

task.handle_lifecycle(LifecycleCommand::Start { reply: start_tx })
.await;
start_rx
.await
.expect("start reply should send")
.expect("start should succeed");

assert!(!observed_rx.await.expect("runtime should observe startup"));
assert!(ctx.persisted_actor().has_initialized);
assert_eq!(ctx.state(), vec![4, 5, 6]);

let run_handle = task.run_handle.take().expect("run handle should exist");
run_handle.abort();
let _ = run_handle.await;
}

#[tokio::test]
async fn startup_uses_preloaded_last_pushed_alarm_without_live_kv() {
let _env_guard = test_inspector_env_lock().lock().expect("env lock poisoned");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ async fn run_preamble(
snapshot: Option<Vec<u8>>,
hibernated: Vec<(rivetkit_core::ConnHandle, Vec<u8>)>,
) -> Result<RunHandlerSlot> {
let snapshot = normalize_startup_snapshot(bindings.create_state.is_some(), snapshot);
let is_new = snapshot.is_none();

// Run database migrations before any user lifecycle hook so `c.db` is
Expand Down Expand Up @@ -290,6 +291,18 @@ async fn run_preamble(
Ok(run_handler)
}

fn normalize_startup_snapshot(
has_create_state: bool,
snapshot: Option<Vec<u8>>,
) -> Option<Vec<u8>> {
// Empty state with createState means a previous process persisted
// initialization before the runtime produced initial state.
match snapshot {
Some(bytes) if bytes.is_empty() && has_create_state => None,
other => other,
}
}

fn configure_run_handler(bindings: &CallbackBindings, ctx: &ActorContext) -> RunHandlerSlot {
let run_handler = Arc::new(Mutex::new(None));
let Some(callback) = bindings.run.as_ref().cloned() else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@ mod moved_tests {
use std::sync::Arc as StdArc;
use std::time::Duration;

use rivet_error::RivetError as RivetTransportError;
use rivet_error::{RivetError as RivetTransportError, RivetErrorSchema};
use rivetkit_actor_persist::versioned as persist_versioned;
use rivetkit_core::Kv;
use rivetkit_core::actor::state::PERSIST_DATA_KEY;
use tokio::sync::oneshot;
use vbare::OwnedVersionedData;

use super::*;

const PERSIST_DATA_KEY: &[u8] = &[1];

fn test_adapter_config() -> AdapterConfig {
let timeout = Duration::from_secs(1);
AdapterConfig {
Expand Down Expand Up @@ -64,6 +64,20 @@ mod moved_tests {
assert_eq!(error.code(), code);
}

#[test]
fn startup_snapshot_recovery_only_treats_empty_stateful_snapshot_as_new() {
assert_eq!(normalize_startup_snapshot(true, Some(Vec::new())), None);
assert_eq!(
normalize_startup_snapshot(false, Some(Vec::new())),
Some(Vec::new())
);
assert_eq!(
normalize_startup_snapshot(true, Some(vec![1, 2, 3])),
Some(vec![1, 2, 3])
);
assert_eq!(normalize_startup_snapshot(true, None), None);
}

fn schema_ptr(error: &anyhow::Error) -> *const RivetErrorSchema {
error
.chain()
Expand Down
Loading