From 8bb181640197123c0c4f7b92518f1e5a959f1f55 Mon Sep 17 00:00:00 2001 From: Nathan Flurry Date: Sun, 3 May 2026 17:42:26 -0700 Subject: [PATCH] fix(gateway): close stopping actor websockets --- .../guard/src/routing/pegboard_gateway/mod.rs | 53 ++++++++++++++++++- engine/packages/pegboard-gateway/src/lib.rs | 24 +++++++-- engine/packages/pegboard-gateway2/src/lib.rs | 24 +++++++-- .../rivetkit-core/src/registry/websocket.rs | 25 +++++++++ 4 files changed, 114 insertions(+), 12 deletions(-) diff --git a/engine/packages/guard/src/routing/pegboard_gateway/mod.rs b/engine/packages/guard/src/routing/pegboard_gateway/mod.rs index 70e7c4ef2e..eb01f63d15 100644 --- a/engine/packages/guard/src/routing/pegboard_gateway/mod.rs +++ b/engine/packages/guard/src/routing/pegboard_gateway/mod.rs @@ -4,9 +4,16 @@ mod resolve_actor_query; use std::{sync::Arc, time::Duration}; use anyhow::Result; +use async_trait::async_trait; +use bytes::Bytes; use gas::{ctx::message::SubscriptionHandle, prelude::*}; -use hyper::header::HeaderName; -use rivet_guard_core::{RouteConfig, RouteTarget, RoutingOutput, request_context::RequestContext}; +use http_body_util::Full; +use hyper::{Request, Response, StatusCode, header::HeaderName}; +use rivet_guard_core::{ + CustomServeTrait, ResponseBody, RouteConfig, RouteTarget, RoutingOutput, WebSocketHandle, + request_context::RequestContext, +}; +use tokio_tungstenite::tungstenite::protocol::frame::{CloseFrame, coding::CloseCode}; use super::{ SEC_WEBSOCKET_PROTOCOL, WS_PROTOCOL_ACTOR, WS_PROTOCOL_BYPASS_CONNECTABLE, WS_PROTOCOL_TOKEN, @@ -30,6 +37,35 @@ const RUNNER_POOL_ERROR_CHECK_INTERVAL: Duration = Duration::from_secs(2); pub const X_RIVET_ACTOR: HeaderName = HeaderName::from_static("x-rivet-actor"); +struct StoppingWebSocket; + +#[async_trait] +impl CustomServeTrait for StoppingWebSocket { + async fn handle_request( + &self, + _req: Request>, + _req_ctx: &mut RequestContext, + ) -> Result> { + Ok(Response::builder() + .status(StatusCode::SERVICE_UNAVAILABLE) + .body(ResponseBody::Full(Full::new(Bytes::from_static( + b"Actor is stopping.", + ))))?) + } + + async fn handle_websocket( + &self, + _req_ctx: &mut RequestContext, + _websocket: WebSocketHandle, + _after_hibernation: bool, + ) -> Result> { + Ok(Some(CloseFrame { + code: CloseCode::Error, + reason: "actor.stopping".into(), + })) + } +} + /// Route requests to actor services using path-based routing #[tracing::instrument(skip_all)] pub async fn route_request_path_based( @@ -320,6 +356,7 @@ async fn route_request_inner( actor, stripped_path, bypass_connectable, + req_ctx.is_websocket(), ready_sub2, stopped_sub2, fail_sub2, @@ -335,6 +372,7 @@ async fn route_request_inner( actor, stripped_path, bypass_connectable, + req_ctx.is_websocket(), ready_sub, stopped_sub, fail_sub, @@ -358,6 +396,7 @@ async fn handle_actor_v2( actor: pegboard::ops::actor::get_for_gateway::Output, stripped_path: &str, bypass_connectable: bool, + is_websocket: bool, mut ready_sub: SubscriptionHandle, mut stopped_sub: SubscriptionHandle, mut fail_sub: SubscriptionHandle, @@ -434,6 +473,10 @@ async fn handle_actor_v2( } // Ready timeout _ = tokio::time::sleep(ctx.config().guard().actor_ready_timeout()) => { + if is_websocket && !bypass_connectable && actor.sleeping { + tracing::debug!(?actor_id, "sleeping actor did not become ready before websocket wait timeout"); + return Ok(RoutingOutput::CustomServe(Arc::new(StoppingWebSocket))); + } return Err(errors::ActorReadyTimeout { actor_id }.build()); } } @@ -461,6 +504,7 @@ async fn handle_actor_v1( actor: pegboard::ops::actor::get_for_gateway::Output, stripped_path: &str, bypass_connectable: bool, + is_websocket: bool, mut ready_sub: SubscriptionHandle, mut stopped_sub: SubscriptionHandle, mut fail_sub: SubscriptionHandle, @@ -552,6 +596,7 @@ async fn handle_actor_v1( actor, stripped_path, bypass_connectable, + is_websocket, ready_sub2, stopped_sub2, fail_sub2, @@ -565,6 +610,10 @@ async fn handle_actor_v1( } // Ready timeout _ = tokio::time::sleep(ctx.config().guard().actor_ready_timeout()) => { + if is_websocket && !bypass_connectable && actor.sleeping { + tracing::debug!(?actor_id, "sleeping actor did not become ready before websocket wait timeout"); + return Ok(RoutingOutput::CustomServe(Arc::new(StoppingWebSocket))); + } return Err(errors::ActorReadyTimeout { actor_id }.build()); } } diff --git a/engine/packages/pegboard-gateway/src/lib.rs b/engine/packages/pegboard-gateway/src/lib.rs index 13ee000334..af978f1322 100644 --- a/engine/packages/pegboard-gateway/src/lib.rs +++ b/engine/packages/pegboard-gateway/src/lib.rs @@ -54,6 +54,13 @@ enum LifecycleResult { Aborted, } +fn actor_stopping_close_frame() -> CloseFrame { + CloseFrame { + code: CloseCode::Error, + reason: "actor.stopping".into(), + } +} + pub struct PegboardGateway { ctx: StandaloneCtx, shared_state: SharedState, @@ -303,11 +310,14 @@ impl PegboardGateway { if let Some(msg) = res { match msg { protocol::mk2::ToServerTunnelMessageKind::ToServerWebSocketOpen(msg) => { - return anyhow::Ok(msg); + return anyhow::Ok(Ok(msg)); } protocol::mk2::ToServerTunnelMessageKind::ToServerWebSocketClose(close) => { tracing::warn!(?close, "websocket closed before opening"); - return Err(WebSocketServiceUnavailable.build()); + return anyhow::Ok(Err(CloseFrame { + code: close.code.map_or(CloseCode::Normal, Into::into), + reason: close.reason.unwrap_or_default().into(), + })); } _ => { tracing::warn!( @@ -325,11 +335,11 @@ impl PegboardGateway { } _ = stopped_sub.next() => { tracing::debug!("actor stopped while waiting for websocket open"); - return Err(WebSocketServiceUnavailable.build()); + return anyhow::Ok(Err(actor_stopping_close_frame())); } _ = drop_rx.changed() => { tracing::warn!(reason=?drop_rx.borrow(), "websocket open timeout"); - return Err(WebSocketServiceUnavailable.build()); + return anyhow::Ok(Err(actor_stopping_close_frame())); } } } @@ -343,13 +353,17 @@ impl PegboardGateway { .pegboard() .gateway_websocket_open_timeout_ms(), ); - let open_msg = tokio::time::timeout(websocket_open_timeout, fut) + let open_result = tokio::time::timeout(websocket_open_timeout, fut) .await .map_err(|_| { tracing::warn!("timed out waiting for websocket open from runner"); WebSocketServiceUnavailable.build() })??; + let open_msg = match open_result { + Ok(open_msg) => open_msg, + Err(close_frame) => return Ok(Some(close_frame)), + }; self.shared_state .toggle_hibernation(request_id, open_msg.can_hibernate) diff --git a/engine/packages/pegboard-gateway2/src/lib.rs b/engine/packages/pegboard-gateway2/src/lib.rs index 67392244f6..d0a6f7dfe1 100644 --- a/engine/packages/pegboard-gateway2/src/lib.rs +++ b/engine/packages/pegboard-gateway2/src/lib.rs @@ -50,6 +50,13 @@ enum LifecycleResult { Aborted, } +fn actor_stopping_close_frame() -> CloseFrame { + CloseFrame { + code: CloseCode::Error, + reason: "actor.stopping".into(), + } +} + #[derive(Debug)] enum HibernationLifecycleResult { Continue, @@ -295,11 +302,14 @@ impl PegboardGateway2 { if let Some(msg) = res { match msg { protocol::ToRivetTunnelMessageKind::ToRivetWebSocketOpen(msg) => { - return anyhow::Ok(msg); + return anyhow::Ok(Ok(msg)); } protocol::ToRivetTunnelMessageKind::ToRivetWebSocketClose(close) => { tracing::warn!(?close, "websocket closed before opening"); - return Err(WebSocketServiceUnavailable.build()); + return anyhow::Ok(Err(CloseFrame { + code: close.code.map_or(CloseCode::Normal, Into::into), + reason: close.reason.unwrap_or_default().into(), + })); } _ => { tracing::warn!( @@ -317,11 +327,11 @@ impl PegboardGateway2 { } _ = stopped_sub.next() => { tracing::debug!("actor stopped while waiting for websocket open"); - return Err(WebSocketServiceUnavailable.build()); + return anyhow::Ok(Err(actor_stopping_close_frame())); } _ = drop_rx.changed() => { tracing::warn!(reason=?drop_rx.borrow(), "websocket open timeout"); - return Err(WebSocketServiceUnavailable.build()); + return anyhow::Ok(Err(actor_stopping_close_frame())); } } } @@ -335,13 +345,17 @@ impl PegboardGateway2 { .pegboard() .gateway_websocket_open_timeout_ms(), ); - let open_msg = tokio::time::timeout(websocket_open_timeout, fut) + let open_result = tokio::time::timeout(websocket_open_timeout, fut) .await .map_err(|_| { tracing::warn!("timed out waiting for websocket open from envoy"); WebSocketServiceUnavailable.build() })??; + let open_msg = match open_result { + Ok(open_msg) => open_msg, + Err(close_frame) => return Ok(Some(close_frame)), + }; in_flight_req .toggle_hibernatable(open_msg.can_hibernate) diff --git a/rivetkit-rust/packages/rivetkit-core/src/registry/websocket.rs b/rivetkit-rust/packages/rivetkit-core/src/registry/websocket.rs index acdffac059..ec98ff589a 100644 --- a/rivetkit-rust/packages/rivetkit-core/src/registry/websocket.rs +++ b/rivetkit-rust/packages/rivetkit-core/src/registry/websocket.rs @@ -504,6 +504,13 @@ impl RegistryDispatcher { is_restoring_hibernatable: bool, _sender: WebSocketSender, ) -> Result { + if !is_hibernatable && (!instance.ctx.started() || instance.ctx.sleep_requested()) { + return Ok(closing_websocket_handler(1011, "actor.stopping")); + } + if instance.ctx.destroy_requested() { + return Ok(closing_websocket_handler(1011, "actor.destroying")); + } + let conn_params = websocket_conn_params(headers)?; let websocket_request = Request::from_parts( &request.method, @@ -548,6 +555,7 @@ impl RegistryDispatcher { let conn_for_open = conn.clone(); let ctx_for_message = ctx.clone(); let ctx_for_close = ctx.clone(); + let ctx_for_open = ctx.clone(); let ws = WebSocket::new(); let ctx_for_close_event_region = ctx.clone(); ws.configure_close_event_callback_region(Some(Arc::new(move || { @@ -575,6 +583,14 @@ impl RegistryDispatcher { Box::pin(async move { let callback_ctx = ctx.clone(); ctx.with_websocket_callback(|| async move { + if !is_hibernatable + && (!callback_ctx.started() || callback_ctx.sleep_requested()) + { + ws.close(Some(1011), Some("actor.stopping".to_owned())) + .await; + return; + } + if is_hibernatable && maybe_respond_to_raw_hibernatable_ack_state_probe( &ws, @@ -661,9 +677,18 @@ impl RegistryDispatcher { let ws = ws_for_open.clone(); let actor_id = actor_id_for_open.clone(); let dispatch = dispatch.clone(); + let ctx = ctx_for_open.clone(); Box::pin(async move { let close_sender = sender.clone(); ws.configure_sender(sender); + if !is_hibernatable && (!ctx.started() || ctx.sleep_requested()) { + close_sender.close(Some(1011), Some("actor.stopping".to_owned())); + return; + } + if ctx.destroy_requested() { + close_sender.close(Some(1011), Some("actor.destroying".to_owned())); + return; + } let result = dispatch_websocket_open_through_task( &dispatch, dispatch_capacity,