From f10f8d5ad3cf19fac1828c025e8e305c700f64c2 Mon Sep 17 00:00:00 2001 From: lhear <121179341+lhear@users.noreply.github.com> Date: Sat, 30 May 2026 15:21:02 +0800 Subject: [PATCH] refactor(crypto): use Zeroizing arrays directly instead of AesKey --- src/client/handshake.rs | 19 +++++++++---------- src/crypto/cipher.rs | 33 ++++++++++++++------------------- src/crypto/handshake.rs | 31 +++++++++++++------------------ src/crypto/keys.rs | 2 +- src/crypto/mod.rs | 6 ++++++ src/server/handlers.rs | 18 ++++++++---------- 6 files changed, 51 insertions(+), 58 deletions(-) diff --git a/src/client/handshake.rs b/src/client/handshake.rs index 92e6a31..a169dc1 100644 --- a/src/client/handshake.rs +++ b/src/client/handshake.rs @@ -41,8 +41,8 @@ pub async fn try_pq_connect( let conn_nonce: [u8; 16] = rand::rng().random(); let (upload_key, download_key, target_key) = crypto::derive_connection_keys(master, &conn_nonce); - let upload_cipher = Arc::new(AesFrameCipher::new(upload_key)); - let download_cipher = Arc::new(AesFrameCipher::new(download_key)); + let upload_cipher = Arc::new(AesFrameCipher::new(&upload_key)); + let download_cipher = Arc::new(AesFrameCipher::new(&download_key)); let enc_target = crypto::encrypt_bytes(&target_key, target_host.as_bytes())?; @@ -131,7 +131,6 @@ pub async fn try_pq_connect( utils::race_upload_download(upload_task, download_fut, Some("download failed")).await } -#[allow(clippy::explicit_auto_deref)] pub async fn full_handshake( http_client: &Arc, state: &Arc, @@ -146,8 +145,8 @@ pub async fn full_handshake( let (eph_sk_a, eph_pk_a) = crypto::generate_keypair(); let eph_sk_a = Zeroizing::new(eph_sk_a); let x25519_shared_a = crypto::diffie_hellman(&eph_sk_a, server_pk); - let handshake_key = crypto::derive_handshake_key(&*x25519_shared_a); - let handshake_cipher = AesFrameCipher::new(crypto::AesKey::from(*handshake_key)); + let handshake_key = crypto::derive_handshake_key(&x25519_shared_a); + let handshake_cipher = AesFrameCipher::new(&handshake_key); let (kem_sk, kem_pk) = crypto::generate_mlkem_keypair(); let kem_pk_bytes = kem_pk.to_bytes(); @@ -247,7 +246,7 @@ pub async fn full_handshake( let master = { let ss_mlkem = crypto::mlkem_decapsulate(&kem_sk, &ct); let ss_x25519 = crypto::diffie_hellman(&eph_sk_b, &server_eph_pk); - crypto::derive_initial_master(&*ss_mlkem, &*ss_x25519) + crypto::derive_initial_master(&ss_mlkem, &ss_x25519) }; info!(session_id = %session_id, "handshake complete, master key derived"); @@ -263,13 +262,13 @@ pub async fn full_handshake( let conn_nonce: [u8; 16] = rand::rng().random(); let (upload_key, download_key, target_key) = - crypto::derive_connection_keys(&*master, &conn_nonce); - let upload_cipher = Arc::new(AesFrameCipher::new(upload_key)); - let download_cipher = Arc::new(AesFrameCipher::new(download_key)); + crypto::derive_connection_keys(&master, &conn_nonce); + let upload_cipher = Arc::new(AesFrameCipher::new(&upload_key)); + let download_cipher = Arc::new(AesFrameCipher::new(&download_key)); let enc_target = crypto::encrypt_bytes(&target_key, target_host.as_bytes())?; - let cookie_nonce_key = crypto::derive_cookie_nonce_key(&*master); + let cookie_nonce_key = crypto::derive_cookie_nonce_key(&master); let enc_conn_nonce = crypto::encrypt_bytes(&cookie_nonce_key, &conn_nonce)?; let cookie_val = format!( diff --git a/src/crypto/cipher.rs b/src/crypto/cipher.rs index c868734..4164dbc 100644 --- a/src/crypto/cipher.rs +++ b/src/crypto/cipher.rs @@ -7,7 +7,6 @@ use rand::Rng; use std::io; use zeroize::Zeroizing; -use super::AesKey; use crate::shaper::FrameCipher; const NONCE_LEN: usize = 12; @@ -59,14 +58,14 @@ fn decrypt_with_cipher(cipher: &Aes256Gcm, data: &[u8]) -> Result> { } #[inline] -pub fn encrypt_bytes(key: &AesKey, data: &[u8]) -> Result> { - let cipher = Aes256Gcm::new(key); +pub fn encrypt_bytes(key_z: &Zeroizing<[u8; 32]>, data: &[u8]) -> Result> { + let cipher = Aes256Gcm::new_from_slice(&**key_z).map_err(|_| anyhow!("invalid key length"))?; encrypt_with_cipher(&cipher, data) } #[inline] -pub fn decrypt_bytes(key: &AesKey, data: &[u8]) -> Result> { - let cipher = Aes256Gcm::new(key); +pub fn decrypt_bytes(key_z: &Zeroizing<[u8; 32]>, data: &[u8]) -> Result> { + let cipher = Aes256Gcm::new_from_slice(&**key_z).map_err(|_| anyhow!("invalid key length"))?; decrypt_with_cipher(&cipher, data) } @@ -78,20 +77,16 @@ pub struct AesFrameCipher { impl Clone for AesFrameCipher { #[inline] fn clone(&self) -> Self { - Self::new(AesKey::from(*self.key)) + Self::new(&self.key) } } impl AesFrameCipher { #[inline] - pub fn new(key: AesKey) -> Self { - let mut key_bytes = [0u8; 32]; - key_bytes.copy_from_slice(key.as_ref()); - let cipher = Aes256Gcm::new(&key); - Self { - key: Zeroizing::new(key_bytes), - cipher, - } + pub fn new(key_z: &Zeroizing<[u8; 32]>) -> Self { + let key = Zeroizing::new(**key_z); + let cipher = Aes256Gcm::new_from_slice(&**key_z).expect("32 bytes is valid for Aes256Gcm"); + Self { key, cipher } } } @@ -112,10 +107,10 @@ mod tests { use super::*; use rand::Rng; - fn random_key() -> AesKey { - let mut bytes = [0u8; 32]; - rand::rng().fill_bytes(&mut bytes); - AesKey::from(bytes) + fn random_key() -> Zeroizing<[u8; 32]> { + let mut bytes = Zeroizing::new([0u8; 32]); + rand::rng().fill_bytes(&mut *bytes); + bytes } #[test] @@ -130,7 +125,7 @@ mod tests { #[test] fn frame_cipher_roundtrip() { let key = random_key(); - let cipher = AesFrameCipher::new(key); + let cipher = AesFrameCipher::new(&key); let data = b"frame data for cipher test"; let ct = cipher.encrypt(data).unwrap(); let pt = cipher.decrypt(&ct).unwrap(); diff --git a/src/crypto/handshake.rs b/src/crypto/handshake.rs index 0b48158..c7462b9 100644 --- a/src/crypto/handshake.rs +++ b/src/crypto/handshake.rs @@ -2,8 +2,6 @@ use hkdf::Hkdf; use sha2::Sha256; use zeroize::Zeroizing; -use super::AesKey; - pub fn derive_handshake_key(shared: &[u8; 32]) -> Zeroizing<[u8; 32]> { let hkdf = Hkdf::::new(None, shared); let mut key = Zeroizing::new([0u8; 32]); @@ -12,8 +10,8 @@ pub fn derive_handshake_key(shared: &[u8; 32]) -> Zeroizing<[u8; 32]> { key } -pub fn derive_initial_master(mlkem_ss: &[u8], x25519_ss: &[u8]) -> Zeroizing<[u8; 32]> { - let mut ikm = Vec::with_capacity(mlkem_ss.len() + x25519_ss.len()); +pub fn derive_initial_master(mlkem_ss: &[u8], x25519_ss: &[u8; 32]) -> Zeroizing<[u8; 32]> { + let mut ikm = Zeroizing::new(Vec::with_capacity(mlkem_ss.len() + x25519_ss.len())); ikm.extend_from_slice(mlkem_ss); ikm.extend_from_slice(x25519_ss); let hkdf = Hkdf::::new(Some(b"initial_master_salt"), &ikm); @@ -23,33 +21,30 @@ pub fn derive_initial_master(mlkem_ss: &[u8], x25519_ss: &[u8]) -> Zeroizing<[u8 master } -pub fn derive_cookie_nonce_key(master: &[u8; 32]) -> AesKey { +pub fn derive_cookie_nonce_key(master: &[u8; 32]) -> Zeroizing<[u8; 32]> { let hkdf = Hkdf::::new(None, master); - let mut key = [0u8; 32]; - hkdf.expand(b"cookie_nonce_key", &mut key) + let mut key = Zeroizing::new([0u8; 32]); + hkdf.expand(b"cookie_nonce_key", &mut *key) .expect("32 bytes is valid for HKDF"); - key.into() + key } -pub fn derive_connection_keys( - master: &[u8; 32], - conn_nonce: &[u8; 16], -) -> (AesKey, AesKey, AesKey) { +pub fn derive_connection_keys(master: &[u8; 32], conn_nonce: &[u8; 16]) -> super::ConnectionKeys { let hkdf = Hkdf::::new(None, master); let mut info = Vec::with_capacity(16 + 15); info.extend_from_slice(conn_nonce); info.extend_from_slice(b"connection_keys"); - let mut buf = [0u8; 96]; - hkdf.expand(&info, &mut buf) + let mut buf = Zeroizing::new([0u8; 96]); + hkdf.expand(&info, &mut *buf) .expect("96 bytes is valid for HKDF"); - let mut upload_key = [0u8; 32]; - let mut download_key = [0u8; 32]; - let mut target_key = [0u8; 32]; + let mut upload_key = Zeroizing::new([0u8; 32]); + let mut download_key = Zeroizing::new([0u8; 32]); + let mut target_key = Zeroizing::new([0u8; 32]); upload_key.copy_from_slice(&buf[..32]); download_key.copy_from_slice(&buf[32..64]); target_key.copy_from_slice(&buf[64..]); - (upload_key.into(), download_key.into(), target_key.into()) + (upload_key, download_key, target_key) } #[cfg(test)] diff --git a/src/crypto/keys.rs b/src/crypto/keys.rs index 71608c5..a03d98b 100644 --- a/src/crypto/keys.rs +++ b/src/crypto/keys.rs @@ -29,7 +29,7 @@ fn encode_fixed_32(bytes: &[u8; X25519_KEY_LEN]) -> String { #[inline] fn decode_fixed_32(s: &str) -> Result<[u8; X25519_KEY_LEN]> { let mut out = [0u8; X25519_KEY_LEN]; - let decoded = URL_SAFE_NO_PAD.decode(s.as_bytes())?; + let decoded = Zeroizing::new(URL_SAFE_NO_PAD.decode(s.as_bytes())?); if decoded.len() != X25519_KEY_LEN { return Err(anyhow!("invalid key length")); } diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs index ccf1fc9..aeac139 100644 --- a/src/crypto/mod.rs +++ b/src/crypto/mod.rs @@ -13,3 +13,9 @@ pub use keys::{ }; pub type AesKey = aes_gcm::Key; + +pub type ConnectionKeys = ( + zeroize::Zeroizing<[u8; 32]>, + zeroize::Zeroizing<[u8; 32]>, + zeroize::Zeroizing<[u8; 32]>, +); diff --git a/src/server/handlers.rs b/src/server/handlers.rs index 62254ae..eddf4db 100644 --- a/src/server/handlers.rs +++ b/src/server/handlers.rs @@ -11,7 +11,7 @@ use tracing::{Instrument, debug, info, warn}; use uuid; use zeroize::Zeroizing; -use crate::crypto::{self, AesFrameCipher, AesKey}; +use crate::crypto::{self, AesFrameCipher}; use crate::error::ServerError; use crate::server::constants::{ CONNECT_TIMEOUT, MASTER_EXPIRY, MAX_UPLOAD_BODY_SIZE, UPLOAD_CHANNEL_CAPACITY, @@ -242,7 +242,6 @@ async fn handle_plaintext_download( ) } -#[allow(clippy::explicit_auto_deref)] async fn handle_fresh_handshake( state: Arc, headers: HeaderMap, @@ -271,8 +270,8 @@ async fn handle_fresh_handshake( .ok_or_else(|| ServerError::internal("server private key not configured"))?; let shared_a = crypto::diffie_hellman(private_key, &eph_pk_a); - let handshake_key = crypto::derive_handshake_key(&*shared_a); - let handshake_cipher = AesFrameCipher::new(AesKey::from(*handshake_key)); + let handshake_key = crypto::derive_handshake_key(&shared_a); + let handshake_cipher = AesFrameCipher::new(&handshake_key); let body_bytes = axum::body::to_bytes(body, MAX_UPLOAD_BODY_SIZE) .await @@ -326,7 +325,7 @@ async fn handle_fresh_handshake( let master = { let server_eph_sk = Zeroizing::new(server_eph_sk); let ss_x25519 = crypto::diffie_hellman(&server_eph_sk, &client_eph_pk_b); - crypto::derive_initial_master(&*ss_mlkem, &*ss_x25519) + crypto::derive_initial_master(&ss_mlkem, &ss_x25519) }; let session_id = uuid::Uuid::new_v4().to_string(); @@ -370,7 +369,6 @@ async fn handle_fresh_handshake( Ok(resp) } -#[allow(clippy::explicit_auto_deref)] async fn handle_pq_download( state: Arc, cookie_val: &str, @@ -401,7 +399,7 @@ async fn handle_pq_download( return Err(ServerError::precondition_required("master key expired")); } - let cookie_nonce_key = crypto::derive_cookie_nonce_key(&*master); + let cookie_nonce_key = crypto::derive_cookie_nonce_key(&master); let enc_target = base64::engine::general_purpose::URL_SAFE_NO_PAD .decode(enc_target_b64) @@ -428,15 +426,15 @@ async fn handle_pq_download( drop(entry); let (upload_key, download_key, target_key) = - crypto::derive_connection_keys(&*master, &conn_nonce); + crypto::derive_connection_keys(&master, &conn_nonce); let target_bytes = crypto::decrypt_bytes(&target_key, &enc_target) .map_err(|_| ServerError::bad_request("failed to decrypt target"))?; let target = String::from_utf8(target_bytes) .map_err(|_| ServerError::bad_request("invalid target utf8"))?; - let upload_cipher = Arc::new(AesFrameCipher::new(upload_key)); - let download_cipher: Arc = Arc::new(AesFrameCipher::new(download_key)); + let upload_cipher = Arc::new(AesFrameCipher::new(&upload_key)); + let download_cipher: Arc = Arc::new(AesFrameCipher::new(&download_key)); let (host, port_str) = target .rsplit_once(':')