diff --git a/AGENTS.md b/AGENTS.md index 4ff1e4d..9252165 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -18,7 +18,7 @@ Before calling any change done, `cargo fmt --all`, the `clippy` line above, and ## Workspace -36 crates organized into six layers, with `goat-protocol` at the bottom of the dependency DAG: +37 crates organized into six layers, with `goat-protocol` at the bottom of the dependency DAG: **Infrastructure** - `goat-protocol` — shared wire contract (`Op`, `Event`, `TaskId`); serde only; leaf. @@ -29,6 +29,7 @@ Before calling any change done, `cargo fmt --all`, the `clippy` line above, and - `goat-wire` — daemon/client wire contract; leaf (depends on `goat-protocol` only). The `ClientFrame`/`ServerFrame` envelope ({`SessionId`/`ClientId`/`seq` + payload `Op`/`Event`}), length-delimited JSON codec (`WireConn`), and protocol-version handshake. `Op`/`Event` bodies are wrapped, never modified. - `goat-daemon` — the resident `goatd` (`goat daemon serve`); machine-wide single daemon holding N live sessions keyed by cwd. Owns the session registry, a single seq-stamping event-log pump per session (stamp→log→fan-out), per-window bounded delivery with disconnect-on-overflow, presence broadcast, idle eviction (kept alive while a turn runs or a window is attached or an Ask/Plan is open), orphaned-turn sweep on startup, and the unix-socket listener (`~/.goat-code/daemon.sock`, 0600). Allocates per-session `TaskId`s and echoes a correlation token. - `goat-client` — thin transport the TUI talks to; auto-spawns the daemon if absent, performs the handshake, opens/reattaches a session, and exposes the same `Op`/`Event` channels the TUI already consumes. Owns the bidirectional `IdMap` (client-local ↔ daemon `TaskId`) and seq-gap resync. +- `goat-remote` — network-facing remote access for the daemon. mTLS over WebSocket: the daemon is a tiny `rcgen` CA, devices pair once over an HTTP `/pair` endpoint (one-time high-entropy code, server cert pinned by QR fingerprint, CSR signed by the CA) and thereafter connect to `/ws` presenting their device client cert. A custom `ClientCertVerifier` validates the chain and checks the cert fingerprint against the live device registry on every handshake (revocation works here, no CRL/OCSP). The TCP listener self-gates: it binds only while at least one device is paired or a pairing code is pending, and winds down otherwise — there is no separate enable flag. Depends on `goat-wire`/`goat-protocol` only; `goat-daemon` supplies a `RemoteHandler` that bridges each authenticated WS connection into the shared connection driver as `ClientOrigin::Remote`. Remote = local trust; only pairing issuance and `StopDaemon` stay local-only. - `goat-update` — executable replacement helper for `goat update`; small CLI-only crate with no app-state ownership. - `goat-worktree` — git-worktree management (`enter`/`list`/`remove`); `enter` resolves and returns the worktree path (the agent cwd is injected explicitly, not via process `set_current_dir` for the engine). diff --git a/Cargo.lock b/Cargo.lock index 7ece219..d6aa86a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -166,6 +166,45 @@ dependencies = [ "stable_deref_trait", ] +[[package]] +name = "asn1-rs" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7f43a50ac4fdca5df8e885c21b835997f0a1cdee65494a6847694a98652d9d8" +dependencies = [ + "asn1-rs-derive", + "asn1-rs-impl", + "displaydoc", + "nom 7.1.3", + "num-traits", + "rusticata-macros", + "thiserror 2.0.18", + "time", +] + +[[package]] +name = "asn1-rs-derive" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3109e49b1e4909e9db6515a30c633684d68cdeaa252f215214cb4fa1a5bfee2c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", + "synstructure", +] + +[[package]] +name = "asn1-rs-impl" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b18050c2cd6fe86c3a76584ef5e0baf286d038cda203eb6223df2cc413565f7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + [[package]] name = "async-broadcast" version = "0.7.2" @@ -467,6 +506,15 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7" +[[package]] +name = "bit-vec" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b71798fca2c1fe1086445a7258a4bc81e6e49dcd24c8d0dd9a1e57395b603f51" +dependencies = [ + "serde", +] + [[package]] name = "bit_field" version = "0.10.3" @@ -1109,6 +1157,20 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5729f5117e208430e437df2f4843f5e5952997175992d1414f94c57d61e270b4" +[[package]] +name = "der-parser" +version = "10.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07da5016415d5a3c4dd39b11ed26f915f52fc4e0dc197d87908bc916e51bc1a6" +dependencies = [ + "asn1-rs", + "displaydoc", + "nom 7.1.3", + "num-bigint", + "num-traits", + "rusticata-macros", +] + [[package]] name = "deranged" version = "0.5.8" @@ -1981,12 +2043,18 @@ dependencies = [ "goat-core", "goat-protocol", "goat-providers", + "goat-remote", "goat-store", "goat-wire", + "rcgen", + "rustls", + "rustls-pemfile", "serde_json", "tempfile", "thiserror 2.0.18", "tokio", + "tokio-rustls", + "tokio-tungstenite", "tokio-util", "tracing", ] @@ -2122,6 +2190,31 @@ dependencies = [ "tokio", ] +[[package]] +name = "goat-remote" +version = "0.1.5" +dependencies = [ + "base64", + "futures", + "goat-protocol", + "goat-wire", + "qrcode", + "rand 0.10.1", + "rcgen", + "rustls", + "rustls-pemfile", + "serde", + "serde_json", + "sha2 0.11.0", + "tempfile", + "thiserror 2.0.18", + "tokio", + "tokio-rustls", + "tokio-tungstenite", + "tokio-util", + "tracing", +] + [[package]] name = "goat-sandbox" version = "0.1.5" @@ -3633,6 +3726,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "oid-registry" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "12f40cff3dde1b6087cc5d5f5d4d65712f34016a03ed60e9c08dcc392736b5b7" +dependencies = [ + "asn1-rs", +] + [[package]] name = "once_cell" version = "1.21.4" @@ -3768,6 +3870,16 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df94ce210e5bc13cb6651479fa48d14f601d9858cfe0467f43ae157023b938d3" +[[package]] +name = "pem" +version = "3.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d30c53c26bc5b31a98cd02d20f25a7c8567146caf63ed593a9d87b2775291be" +dependencies = [ + "base64", + "serde_core", +] + [[package]] name = "percent-encoding" version = "2.3.2" @@ -4136,6 +4248,12 @@ dependencies = [ "bytemuck", ] +[[package]] +name = "qrcode" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d68782463e408eb1e668cf6152704bd856c78c5b6417adaee3203d8f4c1fc9ec" + [[package]] name = "quantette" version = "0.5.1" @@ -4533,6 +4651,20 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "rcgen" +version = "0.14.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57f6d249aad744e274e682777a50283a225a32705394ee6d5fcc01efa25e4055" +dependencies = [ + "pem", + "ring", + "rustls-pki-types", + "time", + "x509-parser", + "yasna", +] + [[package]] name = "redox_syscall" version = "0.5.18" @@ -4747,6 +4879,15 @@ dependencies = [ "semver", ] +[[package]] +name = "rusticata-macros" +version = "4.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "faf0c4a6ece9950b9abdb62b1cfcf2a68b3b67a10ba445b3bb85be2a293d0632" +dependencies = [ + "nom 7.1.3", +] + [[package]] name = "rustix" version = "0.38.44" @@ -4787,6 +4928,15 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rustls-pemfile" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "rustls-pki-types" version = "1.14.1" @@ -5519,6 +5669,18 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-tungstenite" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d25a406cddcc431a75d3d9afc6a7c0f7428d4891dd973e4d54c56b46127bf857" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite", +] + [[package]] name = "tokio-util" version = "0.7.18" @@ -6887,6 +7049,24 @@ version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ea6fc2961e4ef194dcbfe56bb845534d0dc8098940c7e5c012a258bfec6701bd" +[[package]] +name = "x509-parser" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d43b0f71ce057da06bc0851b23ee24f3f86190b07203dd8f567d0b706a185202" +dependencies = [ + "asn1-rs", + "data-encoding", + "der-parser", + "lazy_static", + "nom 7.1.3", + "oid-registry", + "ring", + "rusticata-macros", + "thiserror 2.0.18", + "time", +] + [[package]] name = "xattr" version = "1.6.1" @@ -6987,6 +7167,16 @@ dependencies = [ "linked-hash-map", ] +[[package]] +name = "yasna" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5f6765e852b9b4dc8e2a76843e4d64d1cea8e79bcde0b6901aea8e7c7f08282" +dependencies = [ + "bit-vec 0.9.1", + "time", +] + [[package]] name = "yoke" version = "0.8.3" diff --git a/Cargo.toml b/Cargo.toml index 8ce8245..fb7a4e2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -49,6 +49,7 @@ goat-worktree = { path = "crates/goat-worktree" } goat-wire = { path = "crates/goat-wire" } goat-daemon = { path = "crates/goat-daemon" } goat-client = { path = "crates/goat-client" } +goat-remote = { path = "crates/goat-remote" } tokio = { version = "1", features = ["rt-multi-thread", "macros", "sync", "time", "signal", "net", "io-util"] } tokio-util = { version = "0.7", features = ["rt", "codec"] } @@ -96,6 +97,13 @@ tracing-appender = "0.2" color-eyre = "0.6" thiserror = "2" +tokio-rustls = { version = "0.26", default-features = false, features = ["ring"] } +rustls = { version = "0.23", default-features = false, features = ["ring", "std"] } +rustls-pemfile = "2" +rcgen = { version = "0.14", default-features = false, features = ["ring", "pem", "x509-parser"] } +tokio-tungstenite = "0.28" +qrcode = { version = "0.14", default-features = false } + [workspace.lints.rust] unsafe_code = "forbid" diff --git a/crates/goat-client/src/lib.rs b/crates/goat-client/src/lib.rs index d9c4cec..80f369e 100644 --- a/crates/goat-client/src/lib.rs +++ b/crates/goat-client/src/lib.rs @@ -192,14 +192,14 @@ fn frame_to_event(frame: ServerFrame) -> Option { ServerFrame::Event { event, .. } => Some(event), ServerFrame::Snapshot { target, - entries, + transcript, context_tokens, compaction_threshold, mode, .. } => target.map(|target| Event::ConversationRestored { target, - entries, + entries: transcript, context_tokens, compaction_threshold, mode, @@ -219,7 +219,7 @@ pub async fn status(socket_path: &Path) -> Result, C expect_welcome(&mut conn).await?; conn.send(&ClientFrame::ListSessions).await?; match conn.recv().await? { - ServerFrame::SessionList { sessions } => Ok(sessions), + ServerFrame::Sessions { sessions } => Ok(sessions), _ => Err(ClientError::Handshake), } } @@ -251,6 +251,68 @@ pub async fn kill_session(socket_path: &Path, session: u64) -> Result<(), Client Ok(()) } +pub struct PairingInfo { + pub code: String, + pub server_fingerprint: String, + pub advertised: Vec, +} + +pub async fn pair_device(socket_path: &Path, label: String) -> Result { + let stream = transport::connect(socket_path).await?; + let mut conn: ClientConn = ClientConn::new(stream); + conn.send(&ClientFrame::Hello { + version: PROTOCOL_VERSION, + }) + .await?; + expect_welcome(&mut conn).await?; + conn.send(&ClientFrame::PairDevice { label }).await?; + match conn.recv().await? { + ServerFrame::PairingCode { + code, + server_fingerprint, + advertised, + } => Ok(PairingInfo { + code, + server_fingerprint, + advertised, + }), + ServerFrame::Error { message } => Err(ClientError::OpenFailed(message)), + _ => Err(ClientError::Handshake), + } +} + +pub async fn list_devices(socket_path: &Path) -> Result, ClientError> { + let stream = transport::connect(socket_path).await?; + let mut conn: ClientConn = ClientConn::new(stream); + conn.send(&ClientFrame::Hello { + version: PROTOCOL_VERSION, + }) + .await?; + expect_welcome(&mut conn).await?; + conn.send(&ClientFrame::ListDevices).await?; + match conn.recv().await? { + ServerFrame::Devices { devices } => Ok(devices), + ServerFrame::Error { message } => Err(ClientError::OpenFailed(message)), + _ => Err(ClientError::Handshake), + } +} + +pub async fn revoke_device(socket_path: &Path, device: String) -> Result { + let stream = transport::connect(socket_path).await?; + let mut conn: ClientConn = ClientConn::new(stream); + conn.send(&ClientFrame::Hello { + version: PROTOCOL_VERSION, + }) + .await?; + expect_welcome(&mut conn).await?; + conn.send(&ClientFrame::RevokeDevice { device }).await?; + match conn.recv().await? { + ServerFrame::DeviceRevoked { ok } => Ok(ok), + ServerFrame::Error { message } => Err(ClientError::OpenFailed(message)), + _ => Err(ClientError::Handshake), + } +} + async fn expect_welcome(conn: &mut ClientConn) -> Result<(), ClientError> { match conn.recv().await? { ServerFrame::Welcome { version, .. } => { diff --git a/crates/goat-code/Cargo.toml b/crates/goat-code/Cargo.toml index e818b33..70b46d3 100644 --- a/crates/goat-code/Cargo.toml +++ b/crates/goat-code/Cargo.toml @@ -34,6 +34,7 @@ tar = { workspace = true } flate2 = { workspace = true } tracing-subscriber = { workspace = true } tracing-appender = { workspace = true } +qrcode = { workspace = true } [dev-dependencies] tempfile = { workspace = true } diff --git a/crates/goat-code/src/cli.rs b/crates/goat-code/src/cli.rs index 5043931..7466516 100644 --- a/crates/goat-code/src/cli.rs +++ b/crates/goat-code/src/cli.rs @@ -28,6 +28,20 @@ pub enum Command { Worktree(WorktreeCommand), #[command(subcommand)] Daemon(DaemonCommand), + #[command(subcommand)] + Remote(RemoteCommand), +} + +#[derive(Subcommand)] +pub enum RemoteCommand { + Pair { + #[arg(long, short)] + label: Option, + }, + #[command(visible_alias = "ls")] + Devices, + #[command(visible_alias = "rm")] + Revoke { device: String }, } #[derive(Subcommand)] diff --git a/crates/goat-code/src/main.rs b/crates/goat-code/src/main.rs index 64aea03..8f87317 100644 --- a/crates/goat-code/src/main.rs +++ b/crates/goat-code/src/main.rs @@ -6,7 +6,7 @@ mod update; use clap::Parser; use color_eyre::eyre::eyre; -use crate::cli::{Cli, Command, DaemonCommand, WorktreeCommand}; +use crate::cli::{Cli, Command, DaemonCommand, RemoteCommand, WorktreeCommand}; #[tokio::main] async fn main() -> color_eyre::Result<()> { @@ -46,6 +46,11 @@ async fn main() -> color_eyre::Result<()> { reject_continue(cli.r#continue)?; run_daemon_command(command).await } + Some(Command::Remote(command)) => { + reject_worktree(cli.worktree.as_ref())?; + reject_continue(cli.r#continue)?; + run_remote_command(command).await + } None => run_tui(cli.worktree, cli.r#continue).await, } } @@ -113,10 +118,12 @@ async fn run_daemon_command(command: DaemonCommand) -> color_eyre::Result<()> { .ok_or_else(|| color_eyre::eyre::eyre!(goat_config::HOME_NOT_FOUND))?; let db_path = goat_config::db_path() .ok_or_else(|| color_eyre::eyre::eyre!(goat_config::HOME_NOT_FOUND))?; + let remote = remote_settings()?; goat_daemon::serve(goat_daemon::DaemonConfig { socket_path, auth_path, db_path, + remote, }) .await .map_err(color_eyre::Report::from) @@ -157,3 +164,81 @@ async fn run_daemon_command(command: DaemonCommand) -> color_eyre::Result<()> { } } } + +fn remote_settings() -> color_eyre::Result> { + let config = goat_config::Config::load(); + let Some(remote_dir) = goat_config::remote_dir() else { + return Ok(None); + }; + let bind = config + .remote + .bind + .parse() + .map_err(|e| color_eyre::eyre::eyre!("invalid remote bind address: {e}"))?; + Ok(Some(goat_daemon::RemoteSettings { + remote_dir, + bind, + advertised: config.remote.advertised, + })) +} + +async fn run_remote_command(command: RemoteCommand) -> color_eyre::Result<()> { + let socket_path = goat_config::socket_path() + .ok_or_else(|| color_eyre::eyre::eyre!(goat_config::HOME_NOT_FOUND))?; + match command { + RemoteCommand::Pair { label } => { + let label = label.unwrap_or_else(|| "device".to_owned()); + let info = goat_client::pair_device(&socket_path, label).await?; + println!("pairing code: {}", info.code); + println!("server fingerprint: {}", info.server_fingerprint); + if info.advertised.is_empty() { + println!("advertised address: (none configured)"); + } else { + println!("advertised address: {}", info.advertised.join(", ")); + } + print_pairing_qr(&info); + Ok(()) + } + RemoteCommand::Devices => { + let devices = goat_client::list_devices(&socket_path).await?; + if devices.is_empty() { + println!("no paired devices"); + } else { + for d in devices { + println!("{} [{}] paired_at={}", d.id, d.label, d.paired_at); + } + } + Ok(()) + } + RemoteCommand::Revoke { device } => { + let ok = goat_client::revoke_device(&socket_path, device.clone()).await?; + if ok { + println!("revoked device {device}"); + } else { + println!("no such device: {device}"); + } + Ok(()) + } + } +} + +fn print_pairing_qr(info: &goat_client::PairingInfo) { + let address = info.advertised.first().cloned().unwrap_or_default(); + let payload = format!( + "goat-pair:code={}&fp={}&addr={}", + info.code, info.server_fingerprint, address + ); + match qrcode::QrCode::new(payload.as_bytes()) { + Ok(code) => { + let rendered = code + .render::() + .quiet_zone(true) + .module_dimensions(2, 1) + .build(); + println!("{rendered}"); + } + Err(_) => { + println!("(could not render QR; use the values above)"); + } + } +} diff --git a/crates/goat-config/src/lib.rs b/crates/goat-config/src/lib.rs index ec6bd06..c4949c6 100644 --- a/crates/goat-config/src/lib.rs +++ b/crates/goat-config/src/lib.rs @@ -18,6 +18,23 @@ pub struct Config { pub browser_enabled: bool, pub mouse_capture_enabled: bool, pub plan_shell_without_sandbox: bool, + pub remote: RemoteConfig, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(default)] +pub struct RemoteConfig { + pub bind: String, + pub advertised: Vec, +} + +impl Default for RemoteConfig { + fn default() -> Self { + Self { + bind: "0.0.0.0:4317".to_owned(), + advertised: Vec::new(), + } + } } impl Default for Config { @@ -28,6 +45,7 @@ impl Default for Config { browser_enabled: false, mouse_capture_enabled: true, plan_shell_without_sandbox: false, + remote: RemoteConfig::default(), } } } @@ -111,6 +129,10 @@ pub fn socket_path() -> Option { app_home().map(|home| home.join("daemon.sock")) } +pub fn remote_dir() -> Option { + app_home().map(|home| home.join("remote")) +} + pub fn update_dir() -> Option { app_home().map(|home| home.join("update")) } @@ -137,7 +159,7 @@ pub fn rate_limits_path() -> Option { #[cfg(test)] mod tests { - use super::{Config, ThemeChoice}; + use super::{Config, RemoteConfig, ThemeChoice}; #[test] fn defaults_to_dark() { @@ -163,6 +185,7 @@ mod tests { browser_enabled: true, mouse_capture_enabled: false, plan_shell_without_sandbox: true, + remote: RemoteConfig::default(), }; let raw = serde_json::to_string(&cfg).unwrap(); assert_eq!(Config::from_json(&raw).unwrap(), cfg); diff --git a/crates/goat-daemon/Cargo.toml b/crates/goat-daemon/Cargo.toml index f3d8085..bc42e3d 100644 --- a/crates/goat-daemon/Cargo.toml +++ b/crates/goat-daemon/Cargo.toml @@ -12,6 +12,7 @@ goat-protocol = { workspace = true } goat-core = { workspace = true } goat-config = { workspace = true } goat-wire = { workspace = true } +goat-remote = { workspace = true } goat-agent = { workspace = true } goat-auth = { workspace = true } goat-store = { workspace = true } @@ -25,6 +26,14 @@ thiserror = { workspace = true } [dev-dependencies] tempfile = { workspace = true } +goat-remote = { workspace = true } +tokio-rustls = { workspace = true } +rustls = { workspace = true } +rustls-pemfile = { workspace = true } +rcgen = { workspace = true } +tokio-tungstenite = { workspace = true } +serde_json = { workspace = true } +futures = { workspace = true } [lints] workspace = true diff --git a/crates/goat-daemon/src/conn.rs b/crates/goat-daemon/src/conn.rs index 6e91eca..30255d1 100644 --- a/crates/goat-daemon/src/conn.rs +++ b/crates/goat-daemon/src/conn.rs @@ -1,7 +1,7 @@ use std::path::PathBuf; -use futures::{SinkExt, StreamExt}; -use goat_wire::transport::Stream; +use futures::{Sink, SinkExt, Stream, StreamExt}; +use goat_wire::transport::Stream as LocalStream; use goat_wire::{ClientFrame, PROTOCOL_VERSION, ServerConn, ServerFrame}; use tokio::sync::mpsc; @@ -9,34 +9,75 @@ use crate::manager::Manager; const CLIENT_QUEUE: usize = 1024; +#[derive(Debug, Clone)] +pub(crate) enum ClientOrigin { + Local, + Remote { device: String }, +} + +impl ClientOrigin { + fn is_local(&self) -> bool { + matches!(self, ClientOrigin::Local) + } +} + pub(crate) async fn handle_connection( - stream: Stream, + stream: LocalStream, manager: Manager, shutdown: tokio_util::sync::CancellationToken, ) { - let mut conn: ServerConn = ServerConn::new(stream); + let conn: ServerConn = ServerConn::new(stream); + let (sink, source) = conn.split(); + serve_connection( + sink, + source, + manager, + shutdown, + ClientOrigin::Local, + tokio_util::sync::CancellationToken::new(), + ) + .await; +} + +pub(crate) async fn serve_connection( + sink: Si, + mut source: St, + manager: Manager, + shutdown: tokio_util::sync::CancellationToken, + origin: ClientOrigin, + disconnect: tokio_util::sync::CancellationToken, +) where + Si: Sink + Send + 'static, + St: Stream> + Unpin, +{ + let mut sink = Box::pin(sink); - if let Ok(ClientFrame::Hello { version }) = conn.recv().await { - if version != PROTOCOL_VERSION { - let _ = conn - .send(&ServerFrame::VersionMismatch { + match source.next().await { + Some(Ok(ClientFrame::Hello { version })) if version == PROTOCOL_VERSION => {} + Some(Ok(ClientFrame::Hello { .. })) => { + let _ = sink + .send(ServerFrame::VersionMismatch { daemon_version: PROTOCOL_VERSION, }) .await; return; } - } else { - let _ = conn - .send(&ServerFrame::Error { - message: "expected Hello".to_owned(), - }) - .await; - return; + _ => { + let _ = sink + .send(ServerFrame::Error { + message: "expected Hello".to_owned(), + }) + .await; + return; + } } let client_id = manager.next_client_id(); - if conn - .send(&ServerFrame::Welcome { + if let ClientOrigin::Remote { device } = &origin { + tracing::info!(client = client_id.0, device = %device, "remote client connected"); + } + if sink + .send(ServerFrame::Welcome { version: PROTOCOL_VERSION, client_id, }) @@ -48,9 +89,6 @@ pub(crate) async fn handle_connection( let (out_tx, mut out_rx) = mpsc::channel::(CLIENT_QUEUE); - let (sink, mut source) = conn.split(); - let mut sink = Box::pin(sink); - let writer = tokio::spawn(async move { while let Some(frame) = out_rx.recv().await { if sink.send(frame).await.is_err() { @@ -60,12 +98,18 @@ pub(crate) async fn handle_connection( }); let mut graceful = false; - while let Some(Ok(frame)) = source.next().await { - match dispatch(&manager, client_id, &out_tx, &shutdown, frame).await { - Disposition::Continue => {} - Disposition::Closed => { - graceful = true; - break; + loop { + tokio::select! { + () = disconnect.cancelled() => break, + next = source.next() => { + let Some(Ok(frame)) = next else { break }; + match dispatch(&manager, client_id, &out_tx, &shutdown, &origin, frame).await { + Disposition::Continue => {} + Disposition::Closed => { + graceful = true; + break; + } + } } } } @@ -89,6 +133,7 @@ async fn dispatch( client_id: goat_wire::ClientId, out_tx: &mpsc::Sender, shutdown: &tokio_util::sync::CancellationToken, + origin: &ClientOrigin, frame: ClientFrame, ) -> Disposition { match frame { @@ -130,7 +175,18 @@ async fn dispatch( } ClientFrame::ListSessions => { let sessions = manager.list_sessions().await; - let _ = out_tx.send(ServerFrame::SessionList { sessions }).await; + let _ = out_tx.send(ServerFrame::Sessions { sessions }).await; + Disposition::Continue + } + ClientFrame::ListDirectory { path } => { + match Manager::list_directory(&path) { + Ok(children) => { + let _ = out_tx.send(ServerFrame::Directory { path, children }).await; + } + Err(message) => { + let _ = out_tx.send(ServerFrame::Error { message }).await; + } + } Disposition::Continue } ClientFrame::KillSession { session } => { @@ -139,10 +195,84 @@ async fn dispatch( } Disposition::Continue } + ClientFrame::PairDevice { label } => { + if origin.is_local() { + match manager.pair_device(label).await { + Ok((code, server_fingerprint, advertised)) => { + let _ = out_tx + .send(ServerFrame::PairingCode { + code, + server_fingerprint, + advertised, + }) + .await; + } + Err(message) => { + let _ = out_tx.send(ServerFrame::Error { message }).await; + } + } + } else { + let _ = out_tx + .send(ServerFrame::Error { + message: "pairing is local-only".to_owned(), + }) + .await; + } + Disposition::Continue + } + ClientFrame::ListDevices => { + match manager.list_devices().await { + Ok(devices) => { + let _ = out_tx.send(ServerFrame::Devices { devices }).await; + } + Err(message) => { + let _ = out_tx.send(ServerFrame::Error { message }).await; + } + } + Disposition::Continue + } + ClientFrame::RevokeDevice { device } => { + match manager.revoke_device(&device).await { + Ok(ok) => { + let _ = out_tx.send(ServerFrame::DeviceRevoked { ok }).await; + } + Err(message) => { + let _ = out_tx.send(ServerFrame::Error { message }).await; + } + } + Disposition::Continue + } ClientFrame::StopDaemon => { - shutdown.cancel(); - Disposition::Closed + if origin.is_local() { + shutdown.cancel(); + Disposition::Closed + } else { + let _ = out_tx + .send(ServerFrame::Error { + message: "StopDaemon is local-only".to_owned(), + }) + .await; + Disposition::Continue + } } ClientFrame::Goodbye => Disposition::Closed, } } + +#[cfg(test)] +mod tests { + use super::ClientOrigin; + + #[test] + fn local_origin_is_local() { + assert!(ClientOrigin::Local.is_local()); + } + + #[test] + fn remote_origin_is_not_local() { + let origin = ClientOrigin::Remote { + device: "abc".to_owned(), + }; + assert!(!origin.is_local()); + } +} diff --git a/crates/goat-daemon/src/lib.rs b/crates/goat-daemon/src/lib.rs index 088f844..5b65e39 100644 --- a/crates/goat-daemon/src/lib.rs +++ b/crates/goat-daemon/src/lib.rs @@ -1,5 +1,6 @@ mod conn; mod manager; +mod remote; mod session; use std::path::{Path, PathBuf}; @@ -14,12 +15,21 @@ pub enum DaemonError { Io(#[from] std::io::Error), #[error("a daemon is already running at {0}")] AlreadyRunning(PathBuf), + #[error("remote error: {0}")] + Remote(#[from] goat_remote::RemoteError), } pub struct DaemonConfig { pub socket_path: PathBuf, pub auth_path: PathBuf, pub db_path: PathBuf, + pub remote: Option, +} + +pub struct RemoteSettings { + pub remote_dir: PathBuf, + pub bind: std::net::SocketAddr, + pub advertised: Vec, } pub async fn serve(config: DaemonConfig) -> Result<(), DaemonError> { @@ -29,6 +39,10 @@ pub async fn serve(config: DaemonConfig) -> Result<(), DaemonError> { let shutdown = tokio_util::sync::CancellationToken::new(); tracing::info!(socket = %config.socket_path.display(), "daemon listening"); + if let Some(remote_settings) = config.remote { + spawn_remote(&manager, &shutdown, remote_settings)?; + } + loop { tokio::select! { () = shutdown.cancelled() => { @@ -52,6 +66,35 @@ pub async fn serve(config: DaemonConfig) -> Result<(), DaemonError> { Ok(()) } +fn spawn_remote( + manager: &Manager, + shutdown: &tokio_util::sync::CancellationToken, + settings: RemoteSettings, +) -> Result<(), DaemonError> { + let devices_path = settings.remote_dir.join("devices.json"); + let devices = goat_remote::Devices::load(devices_path)?; + let config = goat_remote::RemoteConfig { + remote_dir: settings.remote_dir, + bind: settings.bind, + advertised: settings.advertised, + }; + let server = goat_remote::RemoteServer::new(config, devices.clone())?; + manager.set_remote( + server.pairing(), + server.devices(), + server.server_fingerprint().to_owned(), + server.advertised().to_vec(), + ); + let handler = remote::handler(manager.clone(), devices, shutdown.clone()); + let shutdown = shutdown.clone(); + tokio::spawn(async move { + if let Err(err) = server.run(handler, shutdown).await { + tracing::warn!(%err, "remote server stopped"); + } + }); + Ok(()) +} + fn bind(socket_path: &Path) -> Result { if transport::exists(socket_path) && transport::probe_alive(socket_path) { return Err(DaemonError::AlreadyRunning(socket_path.to_path_buf())); diff --git a/crates/goat-daemon/src/manager.rs b/crates/goat-daemon/src/manager.rs index f355010..93f41c6 100644 --- a/crates/goat-daemon/src/manager.rs +++ b/crates/goat-daemon/src/manager.rs @@ -25,6 +25,14 @@ struct ManagerInner { sessions: Mutex, next_session: AtomicU64, next_client: AtomicU64, + remote: Mutex>, +} + +struct RemoteControls { + pairing: goat_remote::Pairing, + devices: goat_remote::Devices, + server_fingerprint: String, + advertised: Vec, } impl Manager { @@ -36,10 +44,70 @@ impl Manager { sessions: Mutex::new(HashMap::new()), next_session: AtomicU64::new(1), next_client: AtomicU64::new(1), + remote: Mutex::new(None), }), } } + pub(crate) fn set_remote( + &self, + pairing: goat_remote::Pairing, + devices: goat_remote::Devices, + server_fingerprint: String, + advertised: Vec, + ) { + let inner = self.inner.clone(); + tokio::spawn(async move { + *inner.remote.lock().await = Some(RemoteControls { + pairing, + devices, + server_fingerprint, + advertised, + }); + }); + } + + pub(crate) async fn pair_device( + &self, + label: String, + ) -> Result<(String, String, Vec), String> { + let guard = self.inner.remote.lock().await; + let controls = guard.as_ref().ok_or("remote is not enabled")?; + let code = controls.pairing.mint(label).await; + Ok(( + code, + controls.server_fingerprint.clone(), + controls.advertised.clone(), + )) + } + + pub(crate) async fn list_devices(&self) -> Result, String> { + let guard = self.inner.remote.lock().await; + let controls = guard.as_ref().ok_or("remote is not enabled")?; + let devices = controls + .devices + .list() + .await + .into_iter() + .map(|d| goat_wire::DeviceInfo { + id: d.id, + label: d.label, + paired_at: d.paired_at, + }) + .collect(); + Ok(devices) + } + + pub(crate) async fn revoke_device(&self, id: &str) -> Result { + let guard = self.inner.remote.lock().await; + let controls = guard.as_ref().ok_or("remote is not enabled")?; + controls + .devices + .revoke(id) + .await + .map_err(|e| format!("revoke: {e}")) + } + pub(crate) fn next_client_id(&self) -> ClientId { ClientId(self.inner.next_client.fetch_add(1, Ordering::Relaxed)) } @@ -164,7 +232,7 @@ impl Manager { session, watermark: snap.watermark, target: snap.target, - entries: snap.entries, + transcript: snap.entries, context_tokens: snap.context_tokens, compaction_threshold: snap.compaction_threshold, mode: snap.mode, @@ -290,6 +358,25 @@ impl Manager { ops.send(op).await.map_err(|_| "engine closed".to_owned()) } + pub(crate) fn list_directory(path: &str) -> Result, String> { + let dir = std::fs::read_dir(path).map_err(|e| format!("read_dir: {e}"))?; + let mut children = Vec::new(); + for entry in dir.flatten() { + let name = entry.file_name().to_string_lossy().into_owned(); + let file_type = entry.file_type().map_err(|e| format!("file_type: {e}"))?; + let kind = if file_type.is_symlink() { + goat_wire::DirEntryKind::Symlink + } else if file_type.is_dir() { + goat_wire::DirEntryKind::Directory + } else { + goat_wire::DirEntryKind::File + }; + children.push(goat_wire::DirEntry { name, kind }); + } + children.sort_by(|a, b| a.name.cmp(&b.name)); + Ok(children) + } + pub(crate) async fn list_sessions(&self) -> Vec { let lives: Vec = { let table = self.inner.sessions.lock().await; diff --git a/crates/goat-daemon/src/remote.rs b/crates/goat-daemon/src/remote.rs new file mode 100644 index 0000000..08c7bff --- /dev/null +++ b/crates/goat-daemon/src/remote.rs @@ -0,0 +1,57 @@ +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::time::Duration; + +use goat_remote::{Device, Devices, RemoteHandler, RemoteSink, RemoteStream}; + +use crate::conn::{ClientOrigin, serve_connection}; +use crate::manager::Manager; + +pub(crate) struct DaemonRemoteHandler { + pub(crate) manager: Manager, + pub(crate) devices: Devices, + pub(crate) shutdown: tokio_util::sync::CancellationToken, +} + +impl RemoteHandler for DaemonRemoteHandler { + fn handle( + &self, + device: Device, + sink: RemoteSink, + stream: RemoteStream, + ) -> Pin + Send>> { + let manager = self.manager.clone(); + let shutdown = self.shutdown.clone(); + let devices = self.devices.clone(); + let fingerprint = device.fingerprint.clone(); + let origin = ClientOrigin::Remote { device: device.id }; + let disconnect = tokio_util::sync::CancellationToken::new(); + let watcher = disconnect.clone(); + Box::pin(async move { + let revocation = tokio::spawn(async move { + loop { + tokio::time::sleep(Duration::from_secs(2)).await; + if !devices.contains_fingerprint(&fingerprint).await { + watcher.cancel(); + break; + } + } + }); + serve_connection(sink, stream, manager, shutdown, origin, disconnect).await; + revocation.abort(); + }) + } +} + +pub(crate) fn handler( + manager: Manager, + devices: Devices, + shutdown: tokio_util::sync::CancellationToken, +) -> Arc { + Arc::new(DaemonRemoteHandler { + manager, + devices, + shutdown, + }) +} diff --git a/crates/goat-daemon/tests/remote_e2e.rs b/crates/goat-daemon/tests/remote_e2e.rs new file mode 100644 index 0000000..d9e4bea --- /dev/null +++ b/crates/goat-daemon/tests/remote_e2e.rs @@ -0,0 +1,344 @@ +use std::path::PathBuf; +use std::sync::Arc; +use std::time::Duration; + +use futures::{SinkExt, StreamExt}; +use goat_wire::transport::{self, Stream}; +use goat_wire::{ClientConn, ClientFrame, PROTOCOL_VERSION, ResumeMode, ServerFrame, WireConn}; +use rustls::pki_types::{CertificateDer, ServerName}; +use tokio_rustls::TlsConnector; +use tokio_tungstenite::tungstenite::Message; + +async fn start_remote_daemon(dir: &std::path::Path, port: u16) -> PathBuf { + let socket = dir.join("d.sock"); + let cfg = goat_daemon::DaemonConfig { + socket_path: socket.clone(), + auth_path: dir.join("auth.json"), + db_path: dir.join("store.sqlite"), + remote: Some(goat_daemon::RemoteSettings { + remote_dir: dir.join("remote"), + bind: format!("127.0.0.1:{port}").parse().unwrap(), + advertised: vec!["127.0.0.1".to_owned()], + }), + }; + tokio::spawn(async move { + let _ = goat_daemon::serve(cfg).await; + }); + for _ in 0..100 { + if transport::connect(&socket).await.is_ok() { + return socket; + } + tokio::time::sleep(Duration::from_millis(50)).await; + } + panic!("daemon did not start"); +} + +async fn local_conn(socket: &std::path::Path) -> ClientConn { + let stream = transport::connect(socket).await.unwrap(); + let mut conn: ClientConn = WireConn::new(stream); + conn.send(&ClientFrame::Hello { + version: PROTOCOL_VERSION, + }) + .await + .unwrap(); + match conn.recv().await.unwrap() { + ServerFrame::Welcome { .. } => {} + other => panic!("expected Welcome, got {other:?}"), + } + conn +} + +async fn mint_code(socket: &std::path::Path) -> (String, String) { + let mut conn = local_conn(socket).await; + conn.send(&ClientFrame::PairDevice { + label: "phone".to_owned(), + }) + .await + .unwrap(); + match conn.recv().await.unwrap() { + ServerFrame::PairingCode { + code, + server_fingerprint, + .. + } => (code, server_fingerprint), + other => panic!("expected PairingCode, got {other:?}"), + } +} + +fn make_csr() -> (rcgen::KeyPair, String) { + let key = rcgen::KeyPair::generate().unwrap(); + let mut params = rcgen::CertificateParams::new(vec!["device".to_owned()]).unwrap(); + params + .distinguished_name + .push(rcgen::DnType::CommonName, "device"); + let csr = params.serialize_request(&key).unwrap().pem().unwrap(); + (key, csr) +} + +#[derive(Debug)] +struct PinnedVerifier { + fingerprint: String, + provider: Arc, +} + +impl rustls::client::danger::ServerCertVerifier for PinnedVerifier { + fn verify_server_cert( + &self, + end_entity: &CertificateDer<'_>, + _intermediates: &[CertificateDer<'_>], + _server_name: &ServerName<'_>, + _ocsp: &[u8], + _now: rustls::pki_types::UnixTime, + ) -> Result { + let got = goat_remote::fingerprint_der(end_entity.as_ref()); + if got == self.fingerprint { + Ok(rustls::client::danger::ServerCertVerified::assertion()) + } else { + Err(rustls::Error::General("server pin mismatch".to_owned())) + } + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> Result { + rustls::crypto::verify_tls12_signature( + message, + cert, + dss, + &self.provider.signature_verification_algorithms, + ) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> Result { + rustls::crypto::verify_tls13_signature( + message, + cert, + dss, + &self.provider.signature_verification_algorithms, + ) + } + + fn supported_verify_schemes(&self) -> Vec { + self.provider + .signature_verification_algorithms + .supported_schemes() + } +} + +fn load_certs(pem: &str) -> Vec> { + let mut reader = pem.as_bytes(); + rustls_pemfile::certs(&mut reader) + .collect::, _>>() + .unwrap() +} + +async fn tls_connect( + port: u16, + fingerprint: &str, + client_cert: Option<( + Vec>, + rustls::pki_types::PrivateKeyDer<'static>, + )>, +) -> tokio_rustls::client::TlsStream { + let provider = Arc::new(rustls::crypto::ring::default_provider()); + let verifier = Arc::new(PinnedVerifier { + fingerprint: fingerprint.to_owned(), + provider: provider.clone(), + }); + let builder = rustls::ClientConfig::builder() + .dangerous() + .with_custom_certificate_verifier(verifier); + let config = match client_cert { + Some((certs, key)) => builder.with_client_auth_cert(certs, key).unwrap(), + None => builder.with_no_client_auth(), + }; + let connector = TlsConnector::from(Arc::new(config)); + let domain = ServerName::try_from("127.0.0.1").unwrap(); + for _ in 0..50 { + if let Ok(tcp) = tokio::net::TcpStream::connect(("127.0.0.1", port)).await + && let Ok(tls) = connector.connect(domain.clone(), tcp).await + { + return tls; + } + tokio::time::sleep(Duration::from_millis(50)).await; + } + panic!("could not establish TLS to remote listener"); +} + +async fn pair_device(port: u16, fingerprint: &str, code: &str) -> (String, String) { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + let tls = tls_connect(port, fingerprint, None).await; + let (key, csr) = make_csr(); + let body = serde_json::json!({ "code": code, "csr_pem": csr }).to_string(); + let request = format!( + "POST /pair HTTP/1.1\r\nHost: 127.0.0.1\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + body.len(), + body + ); + let mut tls = tls; + tls.write_all(request.as_bytes()).await.unwrap(); + let mut response = Vec::new(); + let mut chunk = [0u8; 4096]; + loop { + match tls.read(&mut chunk).await { + Ok(n) if n > 0 => response.extend_from_slice(&chunk[..n]), + _ => break, + } + } + let text = String::from_utf8_lossy(&response); + let body = text.split("\r\n\r\n").nth(1).unwrap_or_default(); + let parsed: serde_json::Value = serde_json::from_str(body).unwrap(); + let device_cert = parsed["device_cert_pem"].as_str().unwrap().to_owned(); + (key.serialize_pem(), device_cert) +} + +fn install_provider() { + let _ = rustls::crypto::ring::default_provider().install_default(); +} + +#[tokio::test] +async fn remote_pair_and_open_session_over_mtls() { + install_provider(); + let dir = tempfile::tempdir().unwrap(); + let port = 47318; + let socket = start_remote_daemon(dir.path(), port).await; + + let (code, fingerprint) = mint_code(&socket).await; + let (key_pem, device_cert_pem) = pair_device(port, &fingerprint, &code).await; + + let certs = load_certs(&device_cert_pem); + let key = rustls_pemfile::private_key(&mut key_pem.as_bytes()) + .unwrap() + .unwrap(); + let tls = tls_connect(port, &fingerprint, Some((certs, key))).await; + + let (mut ws, _resp) = tokio_tungstenite::client_async("ws://127.0.0.1/ws", tls) + .await + .expect("ws upgrade"); + + send_frame( + &mut ws, + &ClientFrame::Hello { + version: PROTOCOL_VERSION, + }, + ) + .await; + match recv_frame(&mut ws).await { + ServerFrame::Welcome { version, .. } => assert_eq!(version, PROTOCOL_VERSION), + other => panic!("expected Welcome, got {other:?}"), + } + + send_frame( + &mut ws, + &ClientFrame::OpenSession { + cwd: dir.path().display().to_string(), + resume: ResumeMode::New, + }, + ) + .await; + match recv_frame(&mut ws).await { + ServerFrame::SessionOpened { .. } => {} + other => panic!("expected SessionOpened, got {other:?}"), + } +} + +#[tokio::test] +async fn revoked_device_cannot_reconnect() { + install_provider(); + let dir = tempfile::tempdir().unwrap(); + let port = 47319; + let socket = start_remote_daemon(dir.path(), port).await; + + let (code, fingerprint) = mint_code(&socket).await; + let (key_pem, device_cert_pem) = pair_device(port, &fingerprint, &code).await; + + let device_id = { + let mut conn = local_conn(&socket).await; + conn.send(&ClientFrame::ListDevices).await.unwrap(); + match conn.recv().await.unwrap() { + ServerFrame::Devices { devices } => devices[0].id.clone(), + other => panic!("expected Devices, got {other:?}"), + } + }; + { + let mut conn = local_conn(&socket).await; + conn.send(&ClientFrame::RevokeDevice { + device: device_id.clone(), + }) + .await + .unwrap(); + match conn.recv().await.unwrap() { + ServerFrame::DeviceRevoked { ok } => assert!(ok), + other => panic!("expected DeviceRevoked, got {other:?}"), + } + } + + let certs = load_certs(&device_cert_pem); + let key = rustls_pemfile::private_key(&mut key_pem.as_bytes()) + .unwrap() + .unwrap(); + let provider = Arc::new(rustls::crypto::ring::default_provider()); + let verifier = Arc::new(PinnedVerifier { + fingerprint: fingerprint.clone(), + provider: provider.clone(), + }); + let config = rustls::ClientConfig::builder() + .dangerous() + .with_custom_certificate_verifier(verifier) + .with_client_auth_cert(certs, key) + .unwrap(); + let connector = TlsConnector::from(Arc::new(config)); + let domain = ServerName::try_from("127.0.0.1").unwrap(); + let outcome: Result<(), std::io::Error> = async { + let tcp = tokio::net::TcpStream::connect(("127.0.0.1", port)).await?; + let tls = connector.connect(domain, tcp).await?; + let (mut ws, _) = tokio_tungstenite::client_async("ws://127.0.0.1/ws", tls) + .await + .map_err(|e| std::io::Error::other(e.to_string()))?; + let hello = serde_json::to_string(&ClientFrame::Hello { + version: PROTOCOL_VERSION, + }) + .unwrap(); + ws.send(Message::Text(hello.into())) + .await + .map_err(|e| std::io::Error::other(e.to_string()))?; + match ws.next().await { + Some(Ok(_)) => Ok(()), + _ => Err(std::io::Error::other("closed")), + } + } + .await; + assert!( + outcome.is_err(), + "revoked device must be refused before any frame exchange" + ); +} + +async fn send_frame(ws: &mut tokio_tungstenite::WebSocketStream, frame: &ClientFrame) +where + S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, +{ + let text = serde_json::to_string(frame).unwrap(); + ws.send(Message::Text(text.into())).await.unwrap(); +} + +async fn recv_frame(ws: &mut tokio_tungstenite::WebSocketStream) -> ServerFrame +where + S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, +{ + loop { + match ws.next().await.expect("ws closed").unwrap() { + Message::Text(text) => return serde_json::from_str(&text).unwrap(), + Message::Binary(bytes) => return serde_json::from_slice(&bytes).unwrap(), + _ => {} + } + } +} diff --git a/crates/goat-daemon/tests/roundtrip.rs b/crates/goat-daemon/tests/roundtrip.rs index eec8a62..01c1014 100644 --- a/crates/goat-daemon/tests/roundtrip.rs +++ b/crates/goat-daemon/tests/roundtrip.rs @@ -12,6 +12,7 @@ async fn start_daemon(dir: &std::path::Path) -> PathBuf { socket_path: socket.clone(), auth_path: auth, db_path: db, + remote: None, }; tokio::spawn(async move { let _ = goat_daemon::serve(cfg).await; @@ -60,10 +61,10 @@ async fn open_session_and_list() { let mut lister = connect(&socket).await; lister.send(&ClientFrame::ListSessions).await.unwrap(); match lister.recv().await.unwrap() { - ServerFrame::SessionList { sessions } => { + ServerFrame::Sessions { sessions } => { assert!(sessions.iter().any(|s| s.session == session)); } - other => panic!("expected SessionList, got {other:?}"), + other => panic!("expected Sessions, got {other:?}"), } } @@ -176,9 +177,9 @@ async fn kill_session_removes_it_from_the_list() { admin.send(&ClientFrame::ListSessions).await.unwrap(); match admin.recv().await.unwrap() { - ServerFrame::SessionList { sessions } => { + ServerFrame::Sessions { sessions } => { assert!(!sessions.iter().any(|s| s.session == session)); } - other => panic!("expected SessionList, got {other:?}"), + other => panic!("expected Sessions, got {other:?}"), } } diff --git a/crates/goat-remote/Cargo.toml b/crates/goat-remote/Cargo.toml new file mode 100644 index 0000000..aef043a --- /dev/null +++ b/crates/goat-remote/Cargo.toml @@ -0,0 +1,34 @@ +[package] +name = "goat-remote" +version.workspace = true +edition.workspace = true +rust-version.workspace = true +license.workspace = true +repository.workspace = true +publish = false + +[dependencies] +goat-protocol = { workspace = true } +goat-wire = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +tokio = { workspace = true } +tokio-util = { workspace = true } +futures = { workspace = true } +tracing = { workspace = true } +thiserror = { workspace = true } +tokio-rustls = { workspace = true } +rustls = { workspace = true } +rustls-pemfile = { workspace = true } +rcgen = { workspace = true } +tokio-tungstenite = { workspace = true } +qrcode = { workspace = true } +rand = { workspace = true } +sha2 = { workspace = true } +base64 = { workspace = true } + +[dev-dependencies] +tempfile = { workspace = true } + +[lints] +workspace = true diff --git a/crates/goat-remote/src/ca.rs b/crates/goat-remote/src/ca.rs new file mode 100644 index 0000000..44f8451 --- /dev/null +++ b/crates/goat-remote/src/ca.rs @@ -0,0 +1,188 @@ +use std::path::{Path, PathBuf}; + +use rcgen::{ + BasicConstraints, CertificateParams, CertificateSigningRequestParams, DnType, IsCa, Issuer, + KeyPair, KeyUsagePurpose, SanType, +}; +use sha2::{Digest, Sha256}; + +use crate::RemoteError; + +pub struct Authority { + dir: PathBuf, + issuer: Issuer<'static, KeyPair>, + ca_cert_pem: String, + server_cert_pem: String, + server_key_pem: String, + server_fingerprint: String, +} + +const CA_CERT: &str = "ca.crt"; +const CA_KEY: &str = "ca.key"; +const SERVER_CERT: &str = "server.crt"; +const SERVER_KEY: &str = "server.key"; + +impl Authority { + pub fn load_or_create(dir: &Path, advertised: &[String]) -> Result { + std::fs::create_dir_all(dir)?; + let ca_cert_path = dir.join(CA_CERT); + let ca_key_path = dir.join(CA_KEY); + + let (ca_cert_pem, ca_key_pem) = if ca_cert_path.exists() && ca_key_path.exists() { + ( + std::fs::read_to_string(&ca_cert_path)?, + std::fs::read_to_string(&ca_key_path)?, + ) + } else { + let key = KeyPair::generate()?; + let mut params = CertificateParams::default(); + params + .distinguished_name + .push(DnType::CommonName, "goat-code remote CA"); + params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained); + params.key_usages = vec![ + KeyUsagePurpose::KeyCertSign, + KeyUsagePurpose::CrlSign, + KeyUsagePurpose::DigitalSignature, + ]; + let cert = params.self_signed(&key)?; + let cert_pem = cert.pem(); + let key_pem = key.serialize_pem(); + write_secret(&ca_key_path, key_pem.as_bytes())?; + std::fs::write(&ca_cert_path, cert_pem.as_bytes())?; + (cert_pem, key_pem) + }; + + let ca_key = KeyPair::from_pem(&ca_key_pem)?; + let issuer = Issuer::from_ca_cert_pem(&ca_cert_pem, ca_key)?; + + let server_cert_path = dir.join(SERVER_CERT); + let server_key_path = dir.join(SERVER_KEY); + let (server_cert_pem, server_key_pem) = + generate_server_leaf(&server_cert_path, &server_key_path, &issuer, advertised)?; + let server_fingerprint = fingerprint_pem(&server_cert_pem)?; + + Ok(Self { + dir: dir.to_path_buf(), + issuer, + ca_cert_pem, + server_cert_pem, + server_key_pem, + server_fingerprint, + }) + } + + pub fn ca_cert_pem(&self) -> &str { + &self.ca_cert_pem + } + + pub fn server_cert_pem(&self) -> &str { + &self.server_cert_pem + } + + pub fn server_key_pem(&self) -> &str { + &self.server_key_pem + } + + pub fn server_fingerprint(&self) -> &str { + &self.server_fingerprint + } + + pub fn sign_device_csr(&self, csr_pem: &str) -> Result { + let params = CertificateSigningRequestParams::from_pem(csr_pem)?; + let cert = params.signed_by(&self.issuer)?; + let cert_pem = cert.pem(); + let fingerprint = fingerprint_pem(&cert_pem)?; + Ok(SignedDevice { + cert_pem, + fingerprint, + }) + } + + pub fn dir(&self) -> &Path { + &self.dir + } +} + +pub struct SignedDevice { + pub cert_pem: String, + pub fingerprint: String, +} + +fn generate_server_leaf( + cert_path: &Path, + key_path: &Path, + issuer: &Issuer<'static, KeyPair>, + advertised: &[String], +) -> Result<(String, String), RemoteError> { + if cert_path.exists() && key_path.exists() { + return Ok(( + std::fs::read_to_string(cert_path)?, + std::fs::read_to_string(key_path)?, + )); + } + let key = KeyPair::generate()?; + let mut params = CertificateParams::default(); + params + .distinguished_name + .push(DnType::CommonName, "goat-code remote server"); + params.subject_alt_names = advertised.iter().map(|s| san_for(s)).collect(); + let cert = params.signed_by(&key, issuer)?; + let cert_pem = cert.pem(); + let key_pem = key.serialize_pem(); + write_secret(key_path, key_pem.as_bytes())?; + std::fs::write(cert_path, cert_pem.as_bytes())?; + Ok((cert_pem, key_pem)) +} + +fn san_for(value: &str) -> SanType { + if let Ok(ip) = value.parse::() { + SanType::IpAddress(ip) + } else { + SanType::DnsName(value.to_owned().try_into().unwrap_or_else(|_| { + "localhost" + .to_owned() + .try_into() + .expect("localhost is valid") + })) + } +} + +pub fn fingerprint_pem(pem: &str) -> Result { + let mut reader = pem.as_bytes(); + let item = rustls_pemfile::certs(&mut reader) + .next() + .ok_or(RemoteError::Pem)? + .map_err(|_| RemoteError::Pem)?; + Ok(fingerprint_der(item.as_ref())) +} + +pub fn fingerprint_der(der: &[u8]) -> String { + use std::fmt::Write; + let digest = Sha256::digest(der); + let mut out = String::with_capacity(digest.len() * 2); + for byte in digest { + let _ = write!(out, "{byte:02x}"); + } + out +} + +#[cfg(unix)] +fn write_secret(path: &Path, bytes: &[u8]) -> Result<(), RemoteError> { + use std::io::Write; + use std::os::unix::fs::OpenOptionsExt; + let mut file = std::fs::OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .mode(0o600) + .open(path)?; + file.write_all(bytes)?; + Ok(()) +} + +#[cfg(not(unix))] +fn write_secret(path: &Path, bytes: &[u8]) -> Result<(), RemoteError> { + std::fs::write(path, bytes)?; + Ok(()) +} diff --git a/crates/goat-remote/src/devices.rs b/crates/goat-remote/src/devices.rs new file mode 100644 index 0000000..677fb97 --- /dev/null +++ b/crates/goat-remote/src/devices.rs @@ -0,0 +1,111 @@ +use std::path::{Path, PathBuf}; +use std::sync::Arc; + +use serde::{Deserialize, Serialize}; +use tokio::sync::{Notify, RwLock}; + +use crate::RemoteError; +use crate::verify::Allowlist; + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct Device { + pub id: String, + pub label: String, + pub fingerprint: String, + pub paired_at: i64, +} + +#[derive(Clone)] +pub struct Devices { + path: PathBuf, + inner: Arc>>, + allowlist: Allowlist, + changed: Arc, +} + +impl Devices { + pub fn load(path: PathBuf) -> Result { + let devices = match std::fs::read(&path) { + Ok(bytes) => serde_json::from_slice::>(&bytes)?, + Err(err) if err.kind() == std::io::ErrorKind::NotFound => Vec::new(), + Err(err) => return Err(err.into()), + }; + let allowlist = Allowlist::default(); + allowlist.replace(devices.iter().map(|d| d.fingerprint.clone())); + Ok(Self { + path, + inner: Arc::new(RwLock::new(devices)), + allowlist, + changed: Arc::new(Notify::new()), + }) + } + + pub fn allowlist(&self) -> Allowlist { + self.allowlist.clone() + } + + pub fn changed(&self) -> Arc { + self.changed.clone() + } + + pub async fn is_empty(&self) -> bool { + self.inner.read().await.is_empty() + } + + pub async fn list(&self) -> Vec { + self.inner.read().await.clone() + } + + pub async fn contains_fingerprint(&self, fingerprint: &str) -> bool { + self.inner + .read() + .await + .iter() + .any(|d| d.fingerprint == fingerprint) + } + + pub async fn find_by_fingerprint(&self, fingerprint: &str) -> Option { + self.inner + .read() + .await + .iter() + .find(|d| d.fingerprint == fingerprint) + .cloned() + } + + pub async fn enroll(&self, device: Device) -> Result<(), RemoteError> { + let mut guard = self.inner.write().await; + guard.retain(|d| d.id != device.id && d.fingerprint != device.fingerprint); + guard.push(device); + persist(&self.path, &guard)?; + self.allowlist + .replace(guard.iter().map(|d| d.fingerprint.clone())); + drop(guard); + self.changed.notify_waiters(); + Ok(()) + } + + pub async fn revoke(&self, id: &str) -> Result { + let mut guard = self.inner.write().await; + let before = guard.len(); + guard.retain(|d| d.id != id); + if guard.len() == before { + return Ok(false); + } + persist(&self.path, &guard)?; + self.allowlist + .replace(guard.iter().map(|d| d.fingerprint.clone())); + drop(guard); + self.changed.notify_waiters(); + Ok(true) + } +} + +fn persist(path: &Path, devices: &[Device]) -> Result<(), RemoteError> { + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent)?; + } + let bytes = serde_json::to_vec_pretty(devices)?; + std::fs::write(path, bytes)?; + Ok(()) +} diff --git a/crates/goat-remote/src/lib.rs b/crates/goat-remote/src/lib.rs new file mode 100644 index 0000000..1aaf7c2 --- /dev/null +++ b/crates/goat-remote/src/lib.rs @@ -0,0 +1,53 @@ +mod ca; +mod devices; +mod pairing; +mod server; +mod verify; + +use std::future::Future; +use std::path::PathBuf; +use std::pin::Pin; + +use futures::{Sink, Stream}; +use goat_wire::{ClientFrame, ServerFrame, WireError}; + +pub use ca::{Authority, SignedDevice, fingerprint_der, fingerprint_pem}; +pub use devices::{Device, Devices}; +pub use pairing::Pairing; +pub use verify::{Allowlist, DeviceVerifier}; + +#[derive(Debug, thiserror::Error)] +pub enum RemoteError { + #[error("io error: {0}")] + Io(#[from] std::io::Error), + #[error("json error: {0}")] + Json(#[from] serde_json::Error), + #[error("certificate error: {0}")] + Cert(#[from] rcgen::Error), + #[error("tls error: {0}")] + Tls(#[from] rustls::Error), + #[error("pem decode error")] + Pem, + #[error("bind error: {0}")] + Bind(String), +} + +pub type RemoteSink = Pin + Send>>; +pub type RemoteStream = Pin> + Send>>; + +pub trait RemoteHandler: Send + Sync + 'static { + fn handle( + &self, + device: Device, + sink: RemoteSink, + stream: RemoteStream, + ) -> Pin + Send>>; +} + +pub struct RemoteConfig { + pub remote_dir: PathBuf, + pub bind: std::net::SocketAddr, + pub advertised: Vec, +} + +pub use server::RemoteServer; diff --git a/crates/goat-remote/src/pairing.rs b/crates/goat-remote/src/pairing.rs new file mode 100644 index 0000000..b184dc0 --- /dev/null +++ b/crates/goat-remote/src/pairing.rs @@ -0,0 +1,89 @@ +use std::collections::HashMap; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use tokio::sync::{Mutex, Notify}; + +const CODE_BYTES: usize = 16; +const DEFAULT_TTL: Duration = Duration::from_mins(3); + +#[derive(Clone)] +pub struct Pairing { + inner: Arc>>, + ttl: Duration, + changed: Arc, +} + +struct Pending { + label: String, + expires_at: Instant, +} + +pub struct Claim { + pub label: String, +} + +impl Default for Pairing { + fn default() -> Self { + Self::new(DEFAULT_TTL) + } +} + +impl Pairing { + pub fn new(ttl: Duration) -> Self { + Self { + inner: Arc::new(Mutex::new(HashMap::new())), + ttl, + changed: Arc::new(Notify::new()), + } + } + + pub fn changed(&self) -> Arc { + self.changed.clone() + } + + pub async fn has_pending(&self) -> bool { + let mut guard = self.inner.lock().await; + Self::sweep(&mut guard); + !guard.is_empty() + } + + pub async fn mint(&self, label: String) -> String { + let bytes: [u8; CODE_BYTES] = rand::random(); + let code = encode(&bytes); + let mut guard = self.inner.lock().await; + Self::sweep(&mut guard); + guard.insert( + code.clone(), + Pending { + label, + expires_at: Instant::now() + self.ttl, + }, + ); + drop(guard); + self.changed.notify_waiters(); + code + } + + pub async fn claim(&self, code: &str) -> Option { + let mut guard = self.inner.lock().await; + Self::sweep(&mut guard); + let pending = guard.remove(code)?; + if pending.expires_at <= Instant::now() { + return None; + } + Some(Claim { + label: pending.label, + }) + } + + fn sweep(map: &mut HashMap) { + let now = Instant::now(); + map.retain(|_, p| p.expires_at > now); + } +} + +fn encode(bytes: &[u8]) -> String { + use base64::Engine; + base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes) +} diff --git a/crates/goat-remote/src/server.rs b/crates/goat-remote/src/server.rs new file mode 100644 index 0000000..38b58b7 --- /dev/null +++ b/crates/goat-remote/src/server.rs @@ -0,0 +1,434 @@ +use std::sync::Arc; + +use futures::{SinkExt, StreamExt}; +use rustls::pki_types::{CertificateDer, PrivateKeyDer}; +use rustls::server::WebPkiClientVerifier; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::net::TcpListener; +use tokio_rustls::TlsAcceptor; +use tokio_tungstenite::tungstenite::Message; + +use crate::ca::Authority; +use crate::devices::{Device, Devices}; +use crate::pairing::Pairing; +use crate::verify::DeviceVerifier; +use crate::{RemoteConfig, RemoteError, RemoteHandler, RemoteSink, RemoteStream}; + +pub struct RemoteServer { + authority: Arc, + devices: Devices, + pairing: Pairing, + config: RemoteConfig, +} + +impl RemoteServer { + pub fn new(config: RemoteConfig, devices: Devices) -> Result { + let authority = Authority::load_or_create(&config.remote_dir, &config.advertised)?; + Ok(Self { + authority: Arc::new(authority), + devices, + pairing: Pairing::default(), + config, + }) + } + + pub fn pairing(&self) -> Pairing { + self.pairing.clone() + } + + pub fn devices(&self) -> Devices { + self.devices.clone() + } + + pub fn server_fingerprint(&self) -> &str { + self.authority.server_fingerprint() + } + + pub fn advertised(&self) -> &[String] { + &self.config.advertised + } + + pub async fn run( + self, + handler: Arc, + shutdown: tokio_util::sync::CancellationToken, + ) -> Result<(), RemoteError> + where + H: RemoteHandler, + { + let tls = self.build_tls_config()?; + let acceptor = TlsAcceptor::from(Arc::new(tls)); + let devices_changed = self.devices.changed(); + let pairing_changed = self.pairing.changed(); + let server = Arc::new(self); + + loop { + let should_listen = + !server.devices.is_empty().await || server.pairing.has_pending().await; + if !should_listen { + tokio::select! { + () = shutdown.cancelled() => break, + () = devices_changed.notified() => continue, + () = pairing_changed.notified() => continue, + } + } + + let listener = match TcpListener::bind(server.config.bind).await { + Ok(listener) => listener, + Err(err) => { + tracing::warn!(%err, addr = %server.config.bind, "remote bind failed"); + tokio::select! { + () = shutdown.cancelled() => break, + () = tokio::time::sleep(std::time::Duration::from_secs(5)) => continue, + } + } + }; + tracing::info!(addr = %server.config.bind, "remote listener up"); + + loop { + let wind_down = + !server.devices.is_empty().await || server.pairing.has_pending().await; + if !wind_down { + tracing::info!("no devices or pending pairings; remote listener down"); + break; + } + tokio::select! { + () = shutdown.cancelled() => return Ok(()), + () = devices_changed.notified() => {} + () = pairing_changed.notified() => {} + () = tokio::time::sleep(std::time::Duration::from_secs(30)) => {} + accepted = listener.accept() => { + let Ok((tcp, _peer)) = accepted else { continue }; + let acceptor = acceptor.clone(); + let server = server.clone(); + let handler = handler.clone(); + tokio::spawn(async move { + if let Err(err) = server.serve_one(acceptor, tcp, handler).await { + tracing::debug!(%err, "remote connection ended"); + } + }); + } + } + } + } + Ok(()) + } + + fn build_tls_config(&self) -> Result { + let server_certs = load_certs(self.authority.server_cert_pem())?; + let server_key = load_key(self.authority.server_key_pem())?; + let ca_certs = load_certs(self.authority.ca_cert_pem())?; + + let mut roots = rustls::RootCertStore::empty(); + for cert in ca_certs { + roots.add(cert).map_err(RemoteError::Tls)?; + } + let chain = WebPkiClientVerifier::builder(Arc::new(roots)) + .build() + .map_err(|e| RemoteError::Bind(e.to_string()))?; + let verifier = Arc::new(DeviceVerifier::new(chain, self.devices.allowlist())); + + let config = rustls::ServerConfig::builder() + .with_client_cert_verifier(verifier) + .with_single_cert(server_certs, server_key) + .map_err(RemoteError::Tls)?; + Ok(config) + } + + async fn serve_one( + self: Arc, + acceptor: TlsAcceptor, + tcp: tokio::net::TcpStream, + handler: Arc, + ) -> Result<(), RemoteError> + where + H: RemoteHandler, + { + let mut tls = acceptor.accept(tcp).await?; + let device_fp = tls + .get_ref() + .1 + .peer_certificates() + .and_then(|c| c.first()) + .map(|cert| crate::ca::fingerprint_der(cert.as_ref())); + + let request = read_request_head(&mut tls).await?; + match request.route() { + Route::Pair => self.handle_pair(tls, request).await, + Route::Ws => self.handle_ws(tls, request, device_fp, handler).await, + Route::Unknown => { + write_simple(tls, "404 Not Found").await?; + Ok(()) + } + } + } + + async fn handle_pair(&self, mut tls: S, request: RequestHead) -> Result<(), RemoteError> + where + S: AsyncRead + AsyncWrite + Unpin, + { + let body = read_body(&mut tls, &request).await?; + let req: PairRequest = match serde_json::from_slice(&body) { + Ok(req) => req, + Err(_) => return write_http(&mut tls, 400, b"{\"error\":\"bad request\"}").await, + }; + let Some(claim) = self.pairing.claim(&req.code).await else { + return write_http(&mut tls, 403, b"{\"error\":\"invalid or expired code\"}").await; + }; + let Ok(signed) = self.authority.sign_device_csr(&req.csr_pem) else { + return write_http(&mut tls, 400, b"{\"error\":\"bad csr\"}").await; + }; + let device = Device { + id: short_id(&signed.fingerprint), + label: claim.label, + fingerprint: signed.fingerprint.clone(), + paired_at: now_ms(), + }; + if self.devices.enroll(device).await.is_err() { + return write_http(&mut tls, 500, b"{\"error\":\"enroll failed\"}").await; + } + let response = PairResponse { + device_cert_pem: signed.cert_pem, + ca_cert_pem: self.authority.ca_cert_pem().to_owned(), + }; + let bytes = serde_json::to_vec(&response)?; + write_http(&mut tls, 200, &bytes).await + } + + async fn handle_ws( + &self, + mut tls: S, + request: RequestHead, + device_fp: Option, + handler: Arc, + ) -> Result<(), RemoteError> + where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, + H: RemoteHandler, + { + let Some(fingerprint) = device_fp else { + return write_http( + &mut tls, + 403, + b"{\"error\":\"client certificate required\"}", + ) + .await; + }; + let Some(device) = self.devices.find_by_fingerprint(&fingerprint).await else { + return write_http(&mut tls, 403, b"{\"error\":\"unknown device\"}").await; + }; + let Some(key) = request.ws_key else { + return write_http(&mut tls, 400, b"{\"error\":\"missing websocket key\"}").await; + }; + let accept = ws_accept_key(&key); + let upgrade = format!( + "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: {accept}\r\n\r\n" + ); + { + use tokio::io::AsyncWriteExt; + tls.write_all(upgrade.as_bytes()).await?; + tls.flush().await?; + } + let mut wsconfig = tokio_tungstenite::tungstenite::protocol::WebSocketConfig::default(); + wsconfig.max_message_size = Some(MAX_WS_MESSAGE); + wsconfig.max_frame_size = Some(MAX_WS_MESSAGE); + let ws = tokio_tungstenite::WebSocketStream::from_raw_socket( + tls, + tokio_tungstenite::tungstenite::protocol::Role::Server, + Some(wsconfig), + ) + .await; + let (sink, stream) = frame_adapter(ws); + handler.handle(device, sink, stream).await; + Ok(()) + } +} + +const MAX_WS_MESSAGE: usize = 8 * 1024 * 1024; + +fn frame_adapter(ws: tokio_tungstenite::WebSocketStream) -> (RemoteSink, RemoteStream) +where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + use goat_wire::{ClientFrame, ServerFrame, WireError}; + let (ws_sink, ws_stream) = ws.split(); + let sink = ws_sink + .sink_map_err(|_| WireError::Closed) + .with(|frame: ServerFrame| async move { + let text = serde_json::to_string(&frame).map_err(WireError::Encode)?; + Ok::<_, WireError>(Message::Text(text.into())) + }); + let stream = ws_stream + .filter_map(|item| async move { + match item { + Ok(Message::Text(text)) => { + Some(serde_json::from_str::(&text).map_err(WireError::Decode)) + } + Ok(Message::Binary(bytes)) => { + Some(serde_json::from_slice::(&bytes).map_err(WireError::Decode)) + } + Ok(Message::Close(_)) | Err(_) => Some(Err(WireError::Closed)), + Ok(_) => None, + } + }) + .boxed(); + (Box::pin(sink), stream) +} + +#[derive(serde::Deserialize)] +struct PairRequest { + code: String, + csr_pem: String, +} + +#[derive(serde::Serialize)] +struct PairResponse { + device_cert_pem: String, + ca_cert_pem: String, +} + +enum Route { + Pair, + Ws, + Unknown, +} + +struct RequestHead { + method: String, + path: String, + content_length: usize, + ws_key: Option, +} + +impl RequestHead { + fn route(&self) -> Route { + match (self.method.as_str(), self.path.as_str()) { + ("POST", "/pair") => Route::Pair, + ("GET", "/ws") => Route::Ws, + _ => Route::Unknown, + } + } +} + +async fn read_request_head(reader: &mut S) -> Result +where + S: AsyncRead + Unpin, +{ + use tokio::io::AsyncReadExt; + let mut buf = Vec::with_capacity(1024); + let mut byte = [0u8; 1]; + loop { + let n = reader.read(&mut byte).await?; + if n == 0 { + return Err(RemoteError::Bind( + "connection closed before head".to_owned(), + )); + } + buf.push(byte[0]); + if buf.len() >= 4 && &buf[buf.len() - 4..] == b"\r\n\r\n" { + break; + } + if buf.len() > 16 * 1024 { + return Err(RemoteError::Bind("request head too large".to_owned())); + } + } + let text = String::from_utf8_lossy(&buf); + let mut lines = text.split("\r\n"); + let request_line = lines.next().unwrap_or_default(); + let mut parts = request_line.split_whitespace(); + let method = parts.next().unwrap_or_default().to_owned(); + let path = parts.next().unwrap_or_default().to_owned(); + let mut content_length = 0usize; + let mut ws_key = None; + for line in lines { + if let Some((name, value)) = line.split_once(':') { + let name = name.trim(); + if name.eq_ignore_ascii_case("content-length") { + content_length = value.trim().parse().unwrap_or(0); + } else if name.eq_ignore_ascii_case("sec-websocket-key") { + ws_key = Some(value.trim().to_owned()); + } + } + } + Ok(RequestHead { + method, + path, + content_length, + ws_key, + }) +} + +async fn read_body(reader: &mut S, request: &RequestHead) -> Result, RemoteError> +where + S: AsyncRead + Unpin, +{ + use tokio::io::AsyncReadExt; + if request.content_length == 0 || request.content_length > MAX_WS_MESSAGE { + return Ok(Vec::new()); + } + let mut body = vec![0u8; request.content_length]; + reader.read_exact(&mut body).await?; + Ok(body) +} + +async fn write_http(tls: &mut S, status: u16, body: &[u8]) -> Result<(), RemoteError> +where + S: AsyncRead + AsyncWrite + Unpin, +{ + use tokio::io::AsyncWriteExt; + let reason = match status { + 200 => "OK", + 400 => "Bad Request", + 403 => "Forbidden", + 500 => "Internal Server Error", + _ => "Error", + }; + let header = format!( + "HTTP/1.1 {status} {reason}\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n", + body.len() + ); + tls.write_all(header.as_bytes()).await?; + tls.write_all(body).await?; + tls.flush().await?; + Ok(()) +} + +async fn write_simple(mut tls: S, status: &str) -> Result<(), RemoteError> +where + S: AsyncRead + AsyncWrite + Unpin, +{ + use tokio::io::AsyncWriteExt; + let response = format!("HTTP/1.1 {status}\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"); + tls.write_all(response.as_bytes()).await?; + tls.flush().await?; + Ok(()) +} + +fn load_certs(pem: &str) -> Result>, RemoteError> { + let mut reader = pem.as_bytes(); + rustls_pemfile::certs(&mut reader) + .collect::, _>>() + .map_err(|_| RemoteError::Pem) +} + +fn load_key(pem: &str) -> Result, RemoteError> { + let mut reader = pem.as_bytes(); + rustls_pemfile::private_key(&mut reader) + .map_err(|_| RemoteError::Pem)? + .ok_or(RemoteError::Pem) +} + +fn now_ms() -> i64 { + use std::time::{SystemTime, UNIX_EPOCH}; + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_or(0, |d| i64::try_from(d.as_millis()).unwrap_or(i64::MAX)) +} + +fn ws_accept_key(key: &str) -> String { + tokio_tungstenite::tungstenite::handshake::derive_accept_key(key.as_bytes()) +} + +fn short_id(fingerprint: &str) -> String { + fingerprint.chars().take(12).collect() +} diff --git a/crates/goat-remote/src/verify.rs b/crates/goat-remote/src/verify.rs new file mode 100644 index 0000000..1723bd4 --- /dev/null +++ b/crates/goat-remote/src/verify.rs @@ -0,0 +1,95 @@ +use std::sync::Arc; +use std::sync::RwLock; + +use rustls::DistinguishedName; +use rustls::SignatureScheme; +use rustls::pki_types::{CertificateDer, UnixTime}; +use rustls::server::danger::{ClientCertVerified, ClientCertVerifier}; + +use crate::ca::fingerprint_der; + +#[derive(Debug, Clone, Default)] +pub struct Allowlist { + inner: Arc>>, +} + +impl Allowlist { + pub fn replace(&self, fingerprints: impl IntoIterator) { + let mut guard = self.inner.write().expect("allowlist poisoned"); + *guard = fingerprints.into_iter().collect(); + } + + fn contains(&self, fingerprint: &str) -> bool { + self.inner + .read() + .expect("allowlist poisoned") + .contains(fingerprint) + } +} + +#[derive(Debug)] +pub struct DeviceVerifier { + chain: Arc, + allowlist: Allowlist, +} + +impl DeviceVerifier { + pub fn new(chain: Arc, allowlist: Allowlist) -> Self { + Self { chain, allowlist } + } +} + +impl ClientCertVerifier for DeviceVerifier { + fn root_hint_subjects(&self) -> &[DistinguishedName] { + self.chain.root_hint_subjects() + } + + fn verify_client_cert( + &self, + end_entity: &CertificateDer<'_>, + intermediates: &[CertificateDer<'_>], + now: UnixTime, + ) -> Result { + let verified = self + .chain + .verify_client_cert(end_entity, intermediates, now)?; + let fingerprint = fingerprint_der(end_entity.as_ref()); + if self.allowlist.contains(&fingerprint) { + Ok(verified) + } else { + Err(rustls::Error::General( + "device certificate is not in the active registry".to_owned(), + )) + } + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> Result { + self.chain.verify_tls12_signature(message, cert, dss) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> Result { + self.chain.verify_tls13_signature(message, cert, dss) + } + + fn supported_verify_schemes(&self) -> Vec { + self.chain.supported_verify_schemes() + } + + fn offer_client_auth(&self) -> bool { + true + } + + fn client_auth_mandatory(&self) -> bool { + false + } +} diff --git a/crates/goat-remote/tests/remote.rs b/crates/goat-remote/tests/remote.rs new file mode 100644 index 0000000..02b97c9 --- /dev/null +++ b/crates/goat-remote/tests/remote.rs @@ -0,0 +1,81 @@ +use std::time::Duration; + +use goat_remote::{Authority, Devices, Pairing}; + +#[tokio::test] +async fn pairing_code_is_single_use() { + let pairing = Pairing::default(); + let code = pairing.mint("phone".to_owned()).await; + assert!(pairing.claim(&code).await.is_some()); + assert!(pairing.claim(&code).await.is_none()); +} + +#[tokio::test] +async fn pairing_code_expires() { + let pairing = Pairing::new(Duration::from_millis(10)); + let code = pairing.mint("phone".to_owned()).await; + tokio::time::sleep(Duration::from_millis(30)).await; + assert!(pairing.claim(&code).await.is_none()); +} + +#[tokio::test] +async fn pairing_codes_are_distinct_and_long() { + let pairing = Pairing::default(); + let a = pairing.mint("a".to_owned()).await; + let b = pairing.mint("b".to_owned()).await; + assert_ne!(a, b); + assert!(a.len() >= 20); +} + +#[tokio::test] +async fn enroll_then_revoke_updates_allowlist() { + let dir = tempfile::tempdir().unwrap(); + let devices = Devices::load(dir.path().join("devices.json")).unwrap(); + let allow = devices.allowlist(); + let device = goat_remote::Device { + id: "abc123".to_owned(), + label: "phone".to_owned(), + fingerprint: "deadbeef".to_owned(), + paired_at: 1, + }; + devices.enroll(device).await.unwrap(); + assert!(devices.contains_fingerprint("deadbeef").await); + + let reloaded = Devices::load(dir.path().join("devices.json")).unwrap(); + assert!(reloaded.contains_fingerprint("deadbeef").await); + + assert!(devices.revoke("abc123").await.unwrap()); + assert!(!devices.contains_fingerprint("deadbeef").await); + assert!(!devices.revoke("abc123").await.unwrap()); + let _ = allow; +} + +#[test] +fn ca_signs_device_csr_with_consistent_fingerprint() { + let dir = tempfile::tempdir().unwrap(); + let authority = Authority::load_or_create(dir.path(), &["127.0.0.1".to_owned()]).unwrap(); + + let key = rcgen_keypair(); + let csr_pem = build_csr(&key); + let signed = authority.sign_device_csr(&csr_pem).unwrap(); + let recomputed = goat_remote::fingerprint_pem(&signed.cert_pem).unwrap(); + assert_eq!(signed.fingerprint, recomputed); + + let reopened = Authority::load_or_create(dir.path(), &["127.0.0.1".to_owned()]).unwrap(); + assert_eq!( + reopened.server_fingerprint(), + authority.server_fingerprint() + ); +} + +fn rcgen_keypair() -> rcgen::KeyPair { + rcgen::KeyPair::generate().unwrap() +} + +fn build_csr(key: &rcgen::KeyPair) -> String { + let mut params = rcgen::CertificateParams::new(vec!["device".to_owned()]).unwrap(); + params + .distinguished_name + .push(rcgen::DnType::CommonName, "device"); + params.serialize_request(key).unwrap().pem().unwrap() +} diff --git a/crates/goat-wire/src/lib.rs b/crates/goat-wire/src/lib.rs index 3e36ffd..c1027b6 100644 --- a/crates/goat-wire/src/lib.rs +++ b/crates/goat-wire/src/lib.rs @@ -4,8 +4,8 @@ pub mod transport; pub use codec::{WireConn, WireError}; pub use protocol::{ - ClientFrame, ClientId, PROTOCOL_VERSION, ResumeMode, ServerFrame, SessionId, SessionInfo, - SessionLiveState, + ClientFrame, ClientId, DeviceInfo, DirEntry, DirEntryKind, PROTOCOL_VERSION, ResumeMode, + ServerFrame, SessionId, SessionInfo, SessionLiveState, }; pub type ServerConn = WireConn; @@ -69,4 +69,62 @@ mod tests { client.send(&frame).await.unwrap(); assert_eq!(server.recv().await.unwrap(), frame); } + + #[tokio::test] + async fn directory_frames_roundtrip() { + let (a, b) = tokio::io::duplex(64 * 1024); + let mut server: ServerConn<_> = WireConn::new(a); + let mut client: ClientConn<_> = WireConn::new(b); + + let request = ClientFrame::ListDirectory { + path: "/home/me".to_owned(), + }; + client.send(&request).await.unwrap(); + assert_eq!(server.recv().await.unwrap(), request); + + let response = ServerFrame::Directory { + path: "/home/me".to_owned(), + children: vec![ + DirEntry { + name: "src".to_owned(), + kind: DirEntryKind::Directory, + }, + DirEntry { + name: "main.rs".to_owned(), + kind: DirEntryKind::File, + }, + ], + }; + server.send(&response).await.unwrap(); + assert_eq!(client.recv().await.unwrap(), response); + } + + #[tokio::test] + async fn sessions_and_device_frames_roundtrip() { + let (a, b) = tokio::io::duplex(64 * 1024); + let mut server: ServerConn<_> = WireConn::new(a); + let mut client: ClientConn<_> = WireConn::new(b); + + let sessions = ServerFrame::Sessions { + sessions: Vec::new(), + }; + server.send(&sessions).await.unwrap(); + assert_eq!(client.recv().await.unwrap(), sessions); + + let pair = ClientFrame::PairDevice { + label: "phone".to_owned(), + }; + client.send(&pair).await.unwrap(); + assert_eq!(server.recv().await.unwrap(), pair); + + let devices = ServerFrame::Devices { + devices: vec![DeviceInfo { + id: "abc".to_owned(), + label: "phone".to_owned(), + paired_at: 5, + }], + }; + server.send(&devices).await.unwrap(); + assert_eq!(client.recv().await.unwrap(), devices); + } } diff --git a/crates/goat-wire/src/protocol.rs b/crates/goat-wire/src/protocol.rs index 1a95910..c2863dc 100644 --- a/crates/goat-wire/src/protocol.rs +++ b/crates/goat-wire/src/protocol.rs @@ -32,9 +32,19 @@ pub enum ClientFrame { op: Op, }, ListSessions, + ListDirectory { + path: String, + }, KillSession { session: SessionId, }, + PairDevice { + label: String, + }, + ListDevices, + RevokeDevice { + device: String, + }, StopDaemon, Goodbye, } @@ -59,7 +69,7 @@ pub enum ServerFrame { session: SessionId, watermark: u64, target: Option, - entries: Vec, + transcript: Vec, context_tokens: Option, compaction_threshold: Option, mode: goat_protocol::Mode, @@ -69,9 +79,13 @@ pub enum ServerFrame { seq: u64, event: Event, }, - SessionList { + Sessions { sessions: Vec, }, + Directory { + path: String, + children: Vec, + }, CorrelationAssigned { session: SessionId, correlation: u64, @@ -81,6 +95,17 @@ pub enum ServerFrame { session: SessionId, clients: Vec, }, + PairingCode { + code: String, + server_fingerprint: String, + advertised: Vec, + }, + Devices { + devices: Vec, + }, + DeviceRevoked { + ok: bool, + }, Error { message: String, }, @@ -89,6 +114,13 @@ pub enum ServerFrame { }, } +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct DeviceInfo { + pub id: String, + pub label: String, + pub paired_at: i64, +} + #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct SessionInfo { pub session: SessionId, @@ -105,3 +137,16 @@ pub enum SessionLiveState { Active, WaitingOnAsk, } + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct DirEntry { + pub name: String, + pub kind: DirEntryKind, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum DirEntryKind { + Directory, + File, + Symlink, +}