Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions src/client/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ pub async fn try_pq_connect(
let conn_nonce: [u8; 16] = rand::rng().random();
let (upload_key, download_key, target_key) =
crypto::derive_connection_keys(master, &conn_nonce);
let upload_cipher = Arc::new(AesFrameCipher::new(upload_key));
let download_cipher = Arc::new(AesFrameCipher::new(download_key));
let upload_cipher = Arc::new(AesFrameCipher::new(&upload_key));
let download_cipher = Arc::new(AesFrameCipher::new(&download_key));

let enc_target = crypto::encrypt_bytes(&target_key, target_host.as_bytes())?;

Expand Down Expand Up @@ -131,7 +131,6 @@ pub async fn try_pq_connect(
utils::race_upload_download(upload_task, download_fut, Some("download failed")).await
}

#[allow(clippy::explicit_auto_deref)]
pub async fn full_handshake(
http_client: &Arc<wreq::Client>,
state: &Arc<SharedState>,
Expand All @@ -146,8 +145,8 @@ pub async fn full_handshake(
let (eph_sk_a, eph_pk_a) = crypto::generate_keypair();
let eph_sk_a = Zeroizing::new(eph_sk_a);
let x25519_shared_a = crypto::diffie_hellman(&eph_sk_a, server_pk);
let handshake_key = crypto::derive_handshake_key(&*x25519_shared_a);
let handshake_cipher = AesFrameCipher::new(crypto::AesKey::from(*handshake_key));
let handshake_key = crypto::derive_handshake_key(&x25519_shared_a);
let handshake_cipher = AesFrameCipher::new(&handshake_key);

let (kem_sk, kem_pk) = crypto::generate_mlkem_keypair();
let kem_pk_bytes = kem_pk.to_bytes();
Expand Down Expand Up @@ -247,7 +246,7 @@ pub async fn full_handshake(
let master = {
let ss_mlkem = crypto::mlkem_decapsulate(&kem_sk, &ct);
let ss_x25519 = crypto::diffie_hellman(&eph_sk_b, &server_eph_pk);
crypto::derive_initial_master(&*ss_mlkem, &*ss_x25519)
crypto::derive_initial_master(&ss_mlkem, &ss_x25519)
};

info!(session_id = %session_id, "handshake complete, master key derived");
Expand All @@ -263,13 +262,13 @@ pub async fn full_handshake(

let conn_nonce: [u8; 16] = rand::rng().random();
let (upload_key, download_key, target_key) =
crypto::derive_connection_keys(&*master, &conn_nonce);
let upload_cipher = Arc::new(AesFrameCipher::new(upload_key));
let download_cipher = Arc::new(AesFrameCipher::new(download_key));
crypto::derive_connection_keys(&master, &conn_nonce);
let upload_cipher = Arc::new(AesFrameCipher::new(&upload_key));
let download_cipher = Arc::new(AesFrameCipher::new(&download_key));

let enc_target = crypto::encrypt_bytes(&target_key, target_host.as_bytes())?;

let cookie_nonce_key = crypto::derive_cookie_nonce_key(&*master);
let cookie_nonce_key = crypto::derive_cookie_nonce_key(&master);
let enc_conn_nonce = crypto::encrypt_bytes(&cookie_nonce_key, &conn_nonce)?;

let cookie_val = format!(
Expand Down
33 changes: 14 additions & 19 deletions src/crypto/cipher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use rand::Rng;
use std::io;
use zeroize::Zeroizing;

use super::AesKey;
use crate::shaper::FrameCipher;

const NONCE_LEN: usize = 12;
Expand Down Expand Up @@ -59,14 +58,14 @@ fn decrypt_with_cipher(cipher: &Aes256Gcm, data: &[u8]) -> Result<Vec<u8>> {
}

#[inline]
pub fn encrypt_bytes(key: &AesKey, data: &[u8]) -> Result<Vec<u8>> {
let cipher = Aes256Gcm::new(key);
pub fn encrypt_bytes(key_z: &Zeroizing<[u8; 32]>, data: &[u8]) -> Result<Vec<u8>> {
let cipher = Aes256Gcm::new_from_slice(&**key_z).map_err(|_| anyhow!("invalid key length"))?;
encrypt_with_cipher(&cipher, data)
}

#[inline]
pub fn decrypt_bytes(key: &AesKey, data: &[u8]) -> Result<Vec<u8>> {
let cipher = Aes256Gcm::new(key);
pub fn decrypt_bytes(key_z: &Zeroizing<[u8; 32]>, data: &[u8]) -> Result<Vec<u8>> {
let cipher = Aes256Gcm::new_from_slice(&**key_z).map_err(|_| anyhow!("invalid key length"))?;
decrypt_with_cipher(&cipher, data)
}

Expand All @@ -78,20 +77,16 @@ pub struct AesFrameCipher {
impl Clone for AesFrameCipher {
#[inline]
fn clone(&self) -> Self {
Self::new(AesKey::from(*self.key))
Self::new(&self.key)
}
}

impl AesFrameCipher {
#[inline]
pub fn new(key: AesKey) -> Self {
let mut key_bytes = [0u8; 32];
key_bytes.copy_from_slice(key.as_ref());
let cipher = Aes256Gcm::new(&key);
Self {
key: Zeroizing::new(key_bytes),
cipher,
}
pub fn new(key_z: &Zeroizing<[u8; 32]>) -> Self {
let key = Zeroizing::new(**key_z);
let cipher = Aes256Gcm::new_from_slice(&**key_z).expect("32 bytes is valid for Aes256Gcm");
Self { key, cipher }
}
}

Expand All @@ -112,10 +107,10 @@ mod tests {
use super::*;
use rand::Rng;

fn random_key() -> AesKey {
let mut bytes = [0u8; 32];
rand::rng().fill_bytes(&mut bytes);
AesKey::from(bytes)
fn random_key() -> Zeroizing<[u8; 32]> {
let mut bytes = Zeroizing::new([0u8; 32]);
rand::rng().fill_bytes(&mut *bytes);
bytes
}

#[test]
Expand All @@ -130,7 +125,7 @@ mod tests {
#[test]
fn frame_cipher_roundtrip() {
let key = random_key();
let cipher = AesFrameCipher::new(key);
let cipher = AesFrameCipher::new(&key);
let data = b"frame data for cipher test";
let ct = cipher.encrypt(data).unwrap();
let pt = cipher.decrypt(&ct).unwrap();
Expand Down
31 changes: 13 additions & 18 deletions src/crypto/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ use hkdf::Hkdf;
use sha2::Sha256;
use zeroize::Zeroizing;

use super::AesKey;

pub fn derive_handshake_key(shared: &[u8; 32]) -> Zeroizing<[u8; 32]> {
let hkdf = Hkdf::<Sha256>::new(None, shared);
let mut key = Zeroizing::new([0u8; 32]);
Expand All @@ -12,8 +10,8 @@ pub fn derive_handshake_key(shared: &[u8; 32]) -> Zeroizing<[u8; 32]> {
key
}

pub fn derive_initial_master(mlkem_ss: &[u8], x25519_ss: &[u8]) -> Zeroizing<[u8; 32]> {
let mut ikm = Vec::with_capacity(mlkem_ss.len() + x25519_ss.len());
pub fn derive_initial_master(mlkem_ss: &[u8], x25519_ss: &[u8; 32]) -> Zeroizing<[u8; 32]> {
let mut ikm = Zeroizing::new(Vec::with_capacity(mlkem_ss.len() + x25519_ss.len()));
ikm.extend_from_slice(mlkem_ss);
ikm.extend_from_slice(x25519_ss);
let hkdf = Hkdf::<Sha256>::new(Some(b"initial_master_salt"), &ikm);
Expand All @@ -23,33 +21,30 @@ pub fn derive_initial_master(mlkem_ss: &[u8], x25519_ss: &[u8]) -> Zeroizing<[u8
master
}

pub fn derive_cookie_nonce_key(master: &[u8; 32]) -> AesKey {
pub fn derive_cookie_nonce_key(master: &[u8; 32]) -> Zeroizing<[u8; 32]> {
let hkdf = Hkdf::<Sha256>::new(None, master);
let mut key = [0u8; 32];
hkdf.expand(b"cookie_nonce_key", &mut key)
let mut key = Zeroizing::new([0u8; 32]);
hkdf.expand(b"cookie_nonce_key", &mut *key)
.expect("32 bytes is valid for HKDF");
key.into()
key
}

pub fn derive_connection_keys(
master: &[u8; 32],
conn_nonce: &[u8; 16],
) -> (AesKey, AesKey, AesKey) {
pub fn derive_connection_keys(master: &[u8; 32], conn_nonce: &[u8; 16]) -> super::ConnectionKeys {
let hkdf = Hkdf::<Sha256>::new(None, master);
let mut info = Vec::with_capacity(16 + 15);
info.extend_from_slice(conn_nonce);
info.extend_from_slice(b"connection_keys");
let mut buf = [0u8; 96];
hkdf.expand(&info, &mut buf)
let mut buf = Zeroizing::new([0u8; 96]);
hkdf.expand(&info, &mut *buf)
.expect("96 bytes is valid for HKDF");

let mut upload_key = [0u8; 32];
let mut download_key = [0u8; 32];
let mut target_key = [0u8; 32];
let mut upload_key = Zeroizing::new([0u8; 32]);
let mut download_key = Zeroizing::new([0u8; 32]);
let mut target_key = Zeroizing::new([0u8; 32]);
upload_key.copy_from_slice(&buf[..32]);
download_key.copy_from_slice(&buf[32..64]);
target_key.copy_from_slice(&buf[64..]);
(upload_key.into(), download_key.into(), target_key.into())
(upload_key, download_key, target_key)
}

#[cfg(test)]
Expand Down
2 changes: 1 addition & 1 deletion src/crypto/keys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ fn encode_fixed_32(bytes: &[u8; X25519_KEY_LEN]) -> String {
#[inline]
fn decode_fixed_32(s: &str) -> Result<[u8; X25519_KEY_LEN]> {
let mut out = [0u8; X25519_KEY_LEN];
let decoded = URL_SAFE_NO_PAD.decode(s.as_bytes())?;
let decoded = Zeroizing::new(URL_SAFE_NO_PAD.decode(s.as_bytes())?);
if decoded.len() != X25519_KEY_LEN {
return Err(anyhow!("invalid key length"));
}
Expand Down
6 changes: 6 additions & 0 deletions src/crypto/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,9 @@ pub use keys::{
};

pub type AesKey = aes_gcm::Key<aes_gcm::Aes256Gcm>;

pub type ConnectionKeys = (
zeroize::Zeroizing<[u8; 32]>,
zeroize::Zeroizing<[u8; 32]>,
zeroize::Zeroizing<[u8; 32]>,
);
18 changes: 8 additions & 10 deletions src/server/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use tracing::{Instrument, debug, info, warn};
use uuid;
use zeroize::Zeroizing;

use crate::crypto::{self, AesFrameCipher, AesKey};
use crate::crypto::{self, AesFrameCipher};
use crate::error::ServerError;
use crate::server::constants::{
CONNECT_TIMEOUT, MASTER_EXPIRY, MAX_UPLOAD_BODY_SIZE, UPLOAD_CHANNEL_CAPACITY,
Expand Down Expand Up @@ -242,7 +242,6 @@ async fn handle_plaintext_download(
)
}

#[allow(clippy::explicit_auto_deref)]
async fn handle_fresh_handshake(
state: Arc<AppState>,
headers: HeaderMap,
Expand Down Expand Up @@ -271,8 +270,8 @@ async fn handle_fresh_handshake(
.ok_or_else(|| ServerError::internal("server private key not configured"))?;

let shared_a = crypto::diffie_hellman(private_key, &eph_pk_a);
let handshake_key = crypto::derive_handshake_key(&*shared_a);
let handshake_cipher = AesFrameCipher::new(AesKey::from(*handshake_key));
let handshake_key = crypto::derive_handshake_key(&shared_a);
let handshake_cipher = AesFrameCipher::new(&handshake_key);

let body_bytes = axum::body::to_bytes(body, MAX_UPLOAD_BODY_SIZE)
.await
Expand Down Expand Up @@ -326,7 +325,7 @@ async fn handle_fresh_handshake(
let master = {
let server_eph_sk = Zeroizing::new(server_eph_sk);
let ss_x25519 = crypto::diffie_hellman(&server_eph_sk, &client_eph_pk_b);
crypto::derive_initial_master(&*ss_mlkem, &*ss_x25519)
crypto::derive_initial_master(&ss_mlkem, &ss_x25519)
};

let session_id = uuid::Uuid::new_v4().to_string();
Expand Down Expand Up @@ -370,7 +369,6 @@ async fn handle_fresh_handshake(
Ok(resp)
}

#[allow(clippy::explicit_auto_deref)]
async fn handle_pq_download(
state: Arc<AppState>,
cookie_val: &str,
Expand Down Expand Up @@ -401,7 +399,7 @@ async fn handle_pq_download(
return Err(ServerError::precondition_required("master key expired"));
}

let cookie_nonce_key = crypto::derive_cookie_nonce_key(&*master);
let cookie_nonce_key = crypto::derive_cookie_nonce_key(&master);

let enc_target = base64::engine::general_purpose::URL_SAFE_NO_PAD
.decode(enc_target_b64)
Expand All @@ -428,15 +426,15 @@ async fn handle_pq_download(
drop(entry);

let (upload_key, download_key, target_key) =
crypto::derive_connection_keys(&*master, &conn_nonce);
crypto::derive_connection_keys(&master, &conn_nonce);

let target_bytes = crypto::decrypt_bytes(&target_key, &enc_target)
.map_err(|_| ServerError::bad_request("failed to decrypt target"))?;
let target = String::from_utf8(target_bytes)
.map_err(|_| ServerError::bad_request("invalid target utf8"))?;

let upload_cipher = Arc::new(AesFrameCipher::new(upload_key));
let download_cipher: Arc<dyn FrameCipher> = Arc::new(AesFrameCipher::new(download_key));
let upload_cipher = Arc::new(AesFrameCipher::new(&upload_key));
let download_cipher: Arc<dyn FrameCipher> = Arc::new(AesFrameCipher::new(&download_key));

let (host, port_str) = target
.rsplit_once(':')
Expand Down
Loading