Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 51 additions & 2 deletions engine/packages/guard/src/routing/pegboard_gateway/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,16 @@ mod resolve_actor_query;
use std::{sync::Arc, time::Duration};

use anyhow::Result;
use async_trait::async_trait;
use bytes::Bytes;
use gas::{ctx::message::SubscriptionHandle, prelude::*};
use hyper::header::HeaderName;
use rivet_guard_core::{RouteConfig, RouteTarget, RoutingOutput, request_context::RequestContext};
use http_body_util::Full;
use hyper::{Request, Response, StatusCode, header::HeaderName};
use rivet_guard_core::{
CustomServeTrait, ResponseBody, RouteConfig, RouteTarget, RoutingOutput, WebSocketHandle,
request_context::RequestContext,
};
use tokio_tungstenite::tungstenite::protocol::frame::{CloseFrame, coding::CloseCode};

use super::{
SEC_WEBSOCKET_PROTOCOL, WS_PROTOCOL_ACTOR, WS_PROTOCOL_BYPASS_CONNECTABLE, WS_PROTOCOL_TOKEN,
Expand All @@ -30,6 +37,35 @@ const RUNNER_POOL_ERROR_CHECK_INTERVAL: Duration = Duration::from_secs(2);

pub const X_RIVET_ACTOR: HeaderName = HeaderName::from_static("x-rivet-actor");

/// Custom serve handler returned by the routing code when a sleeping actor
/// does not become ready before the websocket wait timeout elapses.
///
/// It never proxies traffic; both entry points answer immediately with a
/// "stopping" response.
struct StoppingWebSocket;

#[async_trait]
impl CustomServeTrait for StoppingWebSocket {
    /// Answers any HTTP request with `503 Service Unavailable` and the plain
    /// body `Actor is stopping.`. The request and its context are ignored.
    async fn handle_request(
        &self,
        _req: Request<Full<Bytes>>,
        _req_ctx: &mut RequestContext,
    ) -> Result<Response<ResponseBody>> {
        let payload = Full::new(Bytes::from_static(b"Actor is stopping."));
        let response = Response::builder()
            .status(StatusCode::SERVICE_UNAVAILABLE)
            .body(ResponseBody::Full(payload))?;

        Ok(response)
    }

    /// Returns a close frame with `CloseCode::Error` and the reason
    /// `actor.stopping` without reading from or writing to the websocket
    /// handle.
    async fn handle_websocket(
        &self,
        _req_ctx: &mut RequestContext,
        _websocket: WebSocketHandle,
        _after_hibernation: bool,
    ) -> Result<Option<CloseFrame>> {
        let stopping_frame = CloseFrame {
            code: CloseCode::Error,
            reason: "actor.stopping".into(),
        };

        Ok(Some(stopping_frame))
    }
}

/// Route requests to actor services using path-based routing
#[tracing::instrument(skip_all)]
pub async fn route_request_path_based(
Expand Down Expand Up @@ -320,6 +356,7 @@ async fn route_request_inner(
actor,
stripped_path,
bypass_connectable,
req_ctx.is_websocket(),
ready_sub2,
stopped_sub2,
fail_sub2,
Expand All @@ -335,6 +372,7 @@ async fn route_request_inner(
actor,
stripped_path,
bypass_connectable,
req_ctx.is_websocket(),
ready_sub,
stopped_sub,
fail_sub,
Expand All @@ -358,6 +396,7 @@ async fn handle_actor_v2(
actor: pegboard::ops::actor::get_for_gateway::Output,
stripped_path: &str,
bypass_connectable: bool,
is_websocket: bool,
mut ready_sub: SubscriptionHandle<pegboard::workflows::actor2::Ready>,
mut stopped_sub: SubscriptionHandle<pegboard::workflows::actor2::Stopped>,
mut fail_sub: SubscriptionHandle<pegboard::workflows::actor2::Failed>,
Expand Down Expand Up @@ -434,6 +473,10 @@ async fn handle_actor_v2(
}
// Ready timeout
_ = tokio::time::sleep(ctx.config().guard().actor_ready_timeout()) => {
if is_websocket && !bypass_connectable && actor.sleeping {
tracing::debug!(?actor_id, "sleeping actor did not become ready before websocket wait timeout");
return Ok(RoutingOutput::CustomServe(Arc::new(StoppingWebSocket)));
}
return Err(errors::ActorReadyTimeout { actor_id }.build());
}
}
Expand Down Expand Up @@ -461,6 +504,7 @@ async fn handle_actor_v1(
actor: pegboard::ops::actor::get_for_gateway::Output,
stripped_path: &str,
bypass_connectable: bool,
is_websocket: bool,
mut ready_sub: SubscriptionHandle<pegboard::workflows::actor::Ready>,
mut stopped_sub: SubscriptionHandle<pegboard::workflows::actor::Stopped>,
mut fail_sub: SubscriptionHandle<pegboard::workflows::actor::Failed>,
Expand Down Expand Up @@ -552,6 +596,7 @@ async fn handle_actor_v1(
actor,
stripped_path,
bypass_connectable,
is_websocket,
ready_sub2,
stopped_sub2,
fail_sub2,
Expand All @@ -565,6 +610,10 @@ async fn handle_actor_v1(
}
// Ready timeout
_ = tokio::time::sleep(ctx.config().guard().actor_ready_timeout()) => {
if is_websocket && !bypass_connectable && actor.sleeping {
tracing::debug!(?actor_id, "sleeping actor did not become ready before websocket wait timeout");
return Ok(RoutingOutput::CustomServe(Arc::new(StoppingWebSocket)));
}
return Err(errors::ActorReadyTimeout { actor_id }.build());
}
}
Expand Down
24 changes: 19 additions & 5 deletions engine/packages/pegboard-gateway/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ enum LifecycleResult {
Aborted,
}

fn actor_stopping_close_frame() -> CloseFrame {
CloseFrame {
code: CloseCode::Error,
reason: "actor.stopping".into(),
}
}

pub struct PegboardGateway {
ctx: StandaloneCtx,
shared_state: SharedState,
Expand Down Expand Up @@ -303,11 +310,14 @@ impl PegboardGateway {
if let Some(msg) = res {
match msg {
protocol::mk2::ToServerTunnelMessageKind::ToServerWebSocketOpen(msg) => {
return anyhow::Ok(msg);
return anyhow::Ok(Ok(msg));
}
protocol::mk2::ToServerTunnelMessageKind::ToServerWebSocketClose(close) => {
tracing::warn!(?close, "websocket closed before opening");
return Err(WebSocketServiceUnavailable.build());
return anyhow::Ok(Err(CloseFrame {
code: close.code.map_or(CloseCode::Normal, Into::into),
reason: close.reason.unwrap_or_default().into(),
}));
}
_ => {
tracing::warn!(
Expand All @@ -325,11 +335,11 @@ impl PegboardGateway {
}
_ = stopped_sub.next() => {
tracing::debug!("actor stopped while waiting for websocket open");
return Err(WebSocketServiceUnavailable.build());
return anyhow::Ok(Err(actor_stopping_close_frame()));
}
_ = drop_rx.changed() => {
tracing::warn!(reason=?drop_rx.borrow(), "websocket open timeout");
return Err(WebSocketServiceUnavailable.build());
return anyhow::Ok(Err(actor_stopping_close_frame()));
}
}
}
Expand All @@ -343,13 +353,17 @@ impl PegboardGateway {
.pegboard()
.gateway_websocket_open_timeout_ms(),
);
let open_msg = tokio::time::timeout(websocket_open_timeout, fut)
let open_result = tokio::time::timeout(websocket_open_timeout, fut)
.await
.map_err(|_| {
tracing::warn!("timed out waiting for websocket open from runner");

WebSocketServiceUnavailable.build()
})??;
let open_msg = match open_result {
Ok(open_msg) => open_msg,
Err(close_frame) => return Ok(Some(close_frame)),
};

self.shared_state
.toggle_hibernation(request_id, open_msg.can_hibernate)
Expand Down
24 changes: 19 additions & 5 deletions engine/packages/pegboard-gateway2/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ enum LifecycleResult {
Aborted,
}

fn actor_stopping_close_frame() -> CloseFrame {
CloseFrame {
code: CloseCode::Error,
reason: "actor.stopping".into(),
}
}

#[derive(Debug)]
enum HibernationLifecycleResult {
Continue,
Expand Down Expand Up @@ -295,11 +302,14 @@ impl PegboardGateway2 {
if let Some(msg) = res {
match msg {
protocol::ToRivetTunnelMessageKind::ToRivetWebSocketOpen(msg) => {
return anyhow::Ok(msg);
return anyhow::Ok(Ok(msg));
}
protocol::ToRivetTunnelMessageKind::ToRivetWebSocketClose(close) => {
tracing::warn!(?close, "websocket closed before opening");
return Err(WebSocketServiceUnavailable.build());
return anyhow::Ok(Err(CloseFrame {
code: close.code.map_or(CloseCode::Normal, Into::into),
reason: close.reason.unwrap_or_default().into(),
}));
}
_ => {
tracing::warn!(
Expand All @@ -317,11 +327,11 @@ impl PegboardGateway2 {
}
_ = stopped_sub.next() => {
tracing::debug!("actor stopped while waiting for websocket open");
return Err(WebSocketServiceUnavailable.build());
return anyhow::Ok(Err(actor_stopping_close_frame()));
}
_ = drop_rx.changed() => {
tracing::warn!(reason=?drop_rx.borrow(), "websocket open timeout");
return Err(WebSocketServiceUnavailable.build());
return anyhow::Ok(Err(actor_stopping_close_frame()));
}
}
}
Expand All @@ -335,13 +345,17 @@ impl PegboardGateway2 {
.pegboard()
.gateway_websocket_open_timeout_ms(),
);
let open_msg = tokio::time::timeout(websocket_open_timeout, fut)
let open_result = tokio::time::timeout(websocket_open_timeout, fut)
.await
.map_err(|_| {
tracing::warn!("timed out waiting for websocket open from envoy");

WebSocketServiceUnavailable.build()
})??;
let open_msg = match open_result {
Ok(open_msg) => open_msg,
Err(close_frame) => return Ok(Some(close_frame)),
};

in_flight_req
.toggle_hibernatable(open_msg.can_hibernate)
Expand Down
25 changes: 25 additions & 0 deletions rivetkit-rust/packages/rivetkit-core/src/registry/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,13 @@ impl RegistryDispatcher {
is_restoring_hibernatable: bool,
_sender: WebSocketSender,
) -> Result<WebSocketHandler> {
if !is_hibernatable && (!instance.ctx.started() || instance.ctx.sleep_requested()) {
return Ok(closing_websocket_handler(1011, "actor.stopping"));
}
if instance.ctx.destroy_requested() {
return Ok(closing_websocket_handler(1011, "actor.destroying"));
}

let conn_params = websocket_conn_params(headers)?;
let websocket_request = Request::from_parts(
&request.method,
Expand Down Expand Up @@ -548,6 +555,7 @@ impl RegistryDispatcher {
let conn_for_open = conn.clone();
let ctx_for_message = ctx.clone();
let ctx_for_close = ctx.clone();
let ctx_for_open = ctx.clone();
let ws = WebSocket::new();
let ctx_for_close_event_region = ctx.clone();
ws.configure_close_event_callback_region(Some(Arc::new(move || {
Expand Down Expand Up @@ -575,6 +583,14 @@ impl RegistryDispatcher {
Box::pin(async move {
let callback_ctx = ctx.clone();
ctx.with_websocket_callback(|| async move {
if !is_hibernatable
&& (!callback_ctx.started() || callback_ctx.sleep_requested())
{
ws.close(Some(1011), Some("actor.stopping".to_owned()))
.await;
return;
}

if is_hibernatable
&& maybe_respond_to_raw_hibernatable_ack_state_probe(
&ws,
Expand Down Expand Up @@ -661,9 +677,18 @@ impl RegistryDispatcher {
let ws = ws_for_open.clone();
let actor_id = actor_id_for_open.clone();
let dispatch = dispatch.clone();
let ctx = ctx_for_open.clone();
Box::pin(async move {
let close_sender = sender.clone();
ws.configure_sender(sender);
if !is_hibernatable && (!ctx.started() || ctx.sleep_requested()) {
close_sender.close(Some(1011), Some("actor.stopping".to_owned()));
return;
}
if ctx.destroy_requested() {
close_sender.close(Some(1011), Some("actor.destroying".to_owned()));
return;
}
let result = dispatch_websocket_open_through_task(
&dispatch,
dispatch_capacity,
Expand Down
Loading