diff --git a/Cargo.lock b/Cargo.lock index 05e95ab2ec1..a64a3be6a94 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1084,6 +1084,7 @@ dependencies = [ "oak_utils", "prost", "prost-types", + "serde", "tokio", "tonic", "tower", @@ -2020,16 +2021,13 @@ version = "0.1.0" dependencies = [ "anyhow", "assert_matches", - "bincode", + "bytes", "log", "prost", "prost-build", "quickcheck", "quickcheck_macros", "ring", - "serde", - "serde-big-array", - "sha2 0.10.1", ] [[package]] @@ -2887,16 +2885,6 @@ dependencies = [ "serde_derive", ] -[[package]] -name = "serde-big-array" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18b20e7752957bbe9661cff4e0bb04d183d0948cdab2ea58cdb9df36a61dfe62" -dependencies = [ - "serde", - "serde_derive", -] - [[package]] name = "serde_cbor" version = "0.11.2" diff --git a/grpc_attestation/Cargo.toml b/grpc_attestation/Cargo.toml index 40856b270ae..870ba4ca372 100644 --- a/grpc_attestation/Cargo.toml +++ b/grpc_attestation/Cargo.toml @@ -13,6 +13,7 @@ oak_remote_attestation = { path = "../remote_attestation/rust/" } oak_functions_abi = { path = "../oak_functions/abi/" } prost = "*" prost-types = "*" +serde = { version = "*", features = ["derive"] } tokio = { version = "*", features = [ "fs", "macros", diff --git a/oak_functions/loader/fuzz/Cargo.lock b/oak_functions/loader/fuzz/Cargo.lock index 8b3cdf50bc2..31d1fa66f22 100644 --- a/oak_functions/loader/fuzz/Cargo.lock +++ b/oak_functions/loader/fuzz/Cargo.lock @@ -96,15 +96,6 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" -[[package]] -name = "block-buffer" -version = "0.10.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bf7fe51849ea569fd452f37822f606a5cabb684dc918707a0193fd4664ff324" -dependencies = [ - "generic-array", -] - [[package]] name = "bumpalo" version = "3.9.1" @@ -194,24 +185,6 @@ version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5827cebf4670468b8772dd191856768aedcb1b0278a04f989f7766351917b9dc" -[[package]] -name = "cpufeatures" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95059428f66df56b63431fdb4e1947ed2190586af5c5a8a8b71122bdf5a7f469" -dependencies = [ - "libc", -] - -[[package]] -name = "crypto-common" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4600d695eb3f6ce1cd44e6e291adceb2cc3ab12f20a33777ecd0bf6eba34e06" -dependencies = [ - "generic-array", -] - [[package]] name = "derive_arbitrary" version = "1.1.0" @@ -223,16 +196,6 @@ dependencies = [ "syn", ] -[[package]] -name = "digest" -version = "0.10.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8cb780dce4f9a8f5c087362b3a4595936b2019e7c8b30f2c3e9a7e94e6ae9837" -dependencies = [ - "block-buffer", - "crypto-common", -] - [[package]] name = "downcast-rs" version = "1.2.0" @@ -365,16 +328,6 @@ dependencies = [ "slab", ] -[[package]] -name = "generic-array" -version = "0.14.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd48d33ec7f05fbfa152300fdad764757cbded343c1aa1cff2fbaf4134851803" -dependencies = [ - "typenum", - "version_check", -] - [[package]] name = "getrandom" version = "0.2.4" @@ -398,6 +351,7 @@ dependencies = [ "oak_utils", "prost", "prost-types", + "serde", "tokio", "tonic", "tower", @@ -811,14 +765,11 @@ name = "oak_remote_attestation" version = "0.1.0" dependencies = [ "anyhow", - "bincode", + "bytes", "log", "prost", "prost-build", "ring", - "serde", - "serde-big-array", - "sha2", ] [[package]] @@ -1180,16 +1131,6 @@ dependencies = [ "serde_derive", ] -[[package]] -name = "serde-big-array" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18b20e7752957bbe9661cff4e0bb04d183d0948cdab2ea58cdb9df36a61dfe62" -dependencies = [ - "serde", - "serde_derive", -] - [[package]] name = "serde_derive" version = "1.0.136" @@ -1212,17 +1153,6 @@ dependencies = [ "serde", ] -[[package]] -name = "sha2" -version = "0.10.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99c3bd8169c58782adad9290a9af5939994036b76187f7b4f0e6de91dbbfc0ec" -dependencies = [ - "cfg-if", - "cpufeatures", - "digest", -] - [[package]] name = "signal-hook" version = "0.3.13" @@ -1544,12 +1474,6 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642" -[[package]] -name = "typenum" -version = "1.15.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987" - [[package]] name = "unicode-bidi" version = "0.3.7" diff --git a/remote_attestation/rust/Cargo.toml b/remote_attestation/rust/Cargo.toml index 6278d719131..f9eb27aa9ed 100644 --- a/remote_attestation/rust/Cargo.toml +++ b/remote_attestation/rust/Cargo.toml @@ -5,15 +5,16 @@ authors = ["Ivan Petrov "] edition = "2021" license = "Apache-2.0" +[features] +default = [] +std = ["anyhow/std", "prost/std"] + [dependencies] -anyhow = "*" -bincode = "*" +anyhow = { version = "*", default-features = false } +bytes = { version = "*", default-features = false } log = "*" -prost = "*" +prost = { version = "*", default-features = false, features = ["prost-derive"] } ring = "*" -serde = { version = "*", features = ["derive"] } -serde-big-array = { version = "*", features = ["const-generics"] } -sha2 = "*" [build-dependencies] prost-build = "*" diff --git a/remote_attestation/rust/build.rs b/remote_attestation/rust/build.rs index 4ed2f3cbfeb..8a3174c89e8 100644 --- a/remote_attestation/rust/build.rs +++ b/remote_attestation/rust/build.rs @@ -14,11 +14,10 @@ // limitations under the License. // -fn main() -> Result<(), Box> { +fn main() { prost_build::compile_protos( &["remote_attestation/proto/remote_attestation.proto"], &["../.."], ) .expect("Proto compilation failed"); - Ok(()) } diff --git a/remote_attestation/rust/src/crypto.rs b/remote_attestation/rust/src/crypto.rs index 41964bee2c7..105f7214be9 100644 --- a/remote_attestation/rust/src/crypto.rs +++ b/remote_attestation/rust/src/crypto.rs @@ -20,16 +20,17 @@ // protocol. use crate::message::EncryptedData; +use alloc::{format, vec, vec::Vec}; use anyhow::{anyhow, Context}; +use core::convert::TryInto; use ring::{ aead::{self, BoundKey}, agreement, + digest::{digest, SHA256}, hkdf::{Salt, HKDF_SHA256}, rand::{SecureRandom, SystemRandom}, signature::{EcdsaKeyPair, EcdsaSigningAlgorithm, EcdsaVerificationAlgorithm, KeyPair}, }; -use sha2::{digest::Digest, Sha256}; -use std::convert::TryInto; /// Length of the encryption nonce. /// `ring::aead` uses 96-bit (12-byte) nonces. @@ -339,6 +340,7 @@ pub struct Signer { impl Signer { pub fn create() -> anyhow::Result { + // TODO(#2557): Ensure SystemRandom work when building for x86_64 UEFI targets. let rng = ring::rand::SystemRandom::new(); let key_pair_pkcs8 = EcdsaKeyPair::generate_pkcs8(SIGNING_ALGORITHM, &rng) .map_err(|error| anyhow!("Couldn't generate PKCS#8 key pair: {:?}", error))?; @@ -397,11 +399,8 @@ impl SignatureVerifier { /// Computes a SHA-256 digest of `input` and returns it in a form of raw bytes. pub fn get_sha256(input: &[u8]) -> [u8; SHA256_HASH_LENGTH] { - let mut hasher = Sha256::new(); - hasher.update(&input); - hasher - .finalize() - .as_slice() + digest(&SHA256, input) + .as_ref() .try_into() .expect("Incorrect SHA-256 hash length") } diff --git a/remote_attestation/rust/src/handshaker.rs b/remote_attestation/rust/src/handshaker.rs index ecbf7c1c68f..141bfe476a0 100644 --- a/remote_attestation/rust/src/handshaker.rs +++ b/remote_attestation/rust/src/handshaker.rs @@ -34,6 +34,7 @@ use crate::{ }, proto::{AttestationInfo, AttestationReport}, }; +use alloc::{boxed::Box, vec, vec::Vec}; use anyhow::{anyhow, Context}; use prost::Message; @@ -54,8 +55,8 @@ impl Default for ClientHandshakerState { } } -impl std::fmt::Debug for ClientHandshakerState { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { +impl core::fmt::Debug for ClientHandshakerState { + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { match self { Self::Initializing => write!(f, "Initializing"), Self::ExpectingServerIdentity(_) => write!(f, "ExpectingServerIdentity"), @@ -81,8 +82,8 @@ impl Default for ServerHandshakerState { } } -impl std::fmt::Debug for ServerHandshakerState { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { +impl core::fmt::Debug for ServerHandshakerState { + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { match self { Self::ExpectingClientHello => write!(f, "ExpectingClientHello"), Self::ExpectingClientIdentity(_) => write!(f, "ExpectingClientIdentity"), @@ -131,7 +132,7 @@ impl ClientHandshaker { deserialize_message(message).context("Couldn't deserialize message")?; match deserialized_message { MessageWrapper::ServerIdentity(server_identity) => { - match std::mem::take(&mut self.state) { + match core::mem::take(&mut self.state) { ClientHandshakerState::ExpectingServerIdentity(key_negotiator) => { let client_identity = self .process_server_identity(server_identity, key_negotiator) @@ -380,7 +381,7 @@ impl ServerHandshaker { )), }, MessageWrapper::ClientIdentity(client_identity) => { - match std::mem::take(&mut self.state) { + match core::mem::take(&mut self.state) { ServerHandshakerState::ExpectingClientIdentity(key_negotiator) => { self.process_client_identity(client_identity, key_negotiator) .context("Couldn't process client identity message")?; @@ -602,8 +603,8 @@ pub struct AttestationBehavior { signer: Option, } -impl std::fmt::Debug for AttestationBehavior { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { +impl core::fmt::Debug for AttestationBehavior { + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { match ( self.contains_peer_attestation(), self.contains_self_attestation(), @@ -732,6 +733,7 @@ pub fn verify_attestation_info( expected_tee_measurement: &[u8], ) -> anyhow::Result<()> { let attestation_info = AttestationInfo::decode(attestation_info_bytes) + .map_err(anyhow::Error::msg) .context("Couldn't decode attestation info Protobuf message")?; // TODO(#1867): Add remote attestation support, use real TEE reports and check that @@ -760,6 +762,7 @@ pub fn serialize_protobuf(message: &M) -> anyhow::Result std::fmt::Result { +impl core::fmt::Debug for MessageWrapper { + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { match self { Self::ClientHello(_) => write!(f, "ClientHello"), Self::ServerIdentity(_) => write!(f, "ServerIdentity"), @@ -62,22 +71,17 @@ impl std::fmt::Debug for MessageWrapper { // TODO(#2105): Implement challenge-response in remote attestation. // TODO(#2106): Support various claims in remote attestation. -// TODO(#2294): Remove `bincode` and use manual message serialization. /// Initial message that starts remote attestation handshake. -#[derive(Serialize, Deserialize, Clone, PartialEq)] +#[derive(Clone, PartialEq)] pub struct ClientHello { - /// Message header. - header: u8, /// Random vector sent in messages for preventing replay attacks. pub random: [u8; REPLAY_PROTECTION_ARRAY_LENGTH], } /// Server identity message containing remote attestation information and a public key for /// Diffie-Hellman key negotiation. -#[derive(Serialize, Deserialize, Clone, PartialEq)] +#[derive(Clone, PartialEq)] pub struct ServerIdentity { - /// Message header. - header: u8, /// Remote attestation protocol version. pub version: u8, /// Public key needed to establish a session key. @@ -92,7 +96,6 @@ pub struct ServerIdentity { /// /// /// - #[serde(with = "BigArray")] pub transcript_signature: [u8; SIGNATURE_LENGTH], /// Public key used to sign transcripts. /// @@ -102,7 +105,6 @@ pub struct ServerIdentity { /// Where X and Y are big-endian coordinates of an Elliptic Curve point. /// /// - #[serde(with = "BigArray")] pub signing_public_key: [u8; SIGNING_ALGORITHM_KEY_LENGTH], /// Information used for remote attestation such as a TEE report and a TEE provider's /// certificate. TEE report contains a hash of the `signing_public_key` and `additional_info`. @@ -120,10 +122,8 @@ pub struct ServerIdentity { /// Client identity message containing remote attestation information and a public key for /// Diffie-Hellman key negotiation. -#[derive(Serialize, Deserialize, Clone, PartialEq)] +#[derive(Clone, PartialEq)] pub struct ClientIdentity { - /// Message header. - header: u8, /// Public key needed to establish a session key. pub ephemeral_public_key: [u8; KEY_AGREEMENT_ALGORITHM_KEY_LENGTH], /// Signature of the SHA-256 hash of all previously sent and received messages. @@ -134,7 +134,6 @@ pub struct ClientIdentity { /// /// /// - #[serde(with = "BigArray")] pub transcript_signature: [u8; SIGNATURE_LENGTH], /// Public key used to sign transcripts. /// @@ -144,7 +143,6 @@ pub struct ClientIdentity { /// Where X and Y are big-endian coordinates of an Elliptic Curve point. /// /// - #[serde(with = "BigArray")] pub signing_public_key: [u8; SIGNING_ALGORITHM_KEY_LENGTH], /// Information used for remote attestation such as a TEE report and a TEE provider's /// certificate. TEE report contains a hash of the `signing_public_key`. @@ -155,12 +153,9 @@ pub struct ClientIdentity { } /// Message containing data encrypted using a session key. -#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq)] pub struct EncryptedData { - /// Message header. - header: u8, /// Random nonce (initialization vector) used for encryption/decryption. - #[serde(with = "BigArray")] pub nonce: [u8; NONCE_LENGTH], /// Data encrypted using the session key. pub data: Vec, @@ -178,36 +173,39 @@ pub trait Deserializable { impl ClientHello { pub fn new(random: [u8; REPLAY_PROTECTION_ARRAY_LENGTH]) -> Self { - Self { - header: CLIENT_HELLO_HEADER, - random, - } + Self { random } + } + + const fn len() -> usize { + MESSAGE_HEADER_LENGTH + REPLAY_PROTECTION_ARRAY_LENGTH } } impl Serializable for ClientHello { fn serialize(&self) -> anyhow::Result> { - bincode::serialize(&self).context("Couldn't serialize client hello message") + let mut result = Vec::with_capacity(ClientHello::len()); + result.put_u8(CLIENT_HELLO_HEADER); + result.put_slice(&self.random); + Ok(result) } } impl Deserializable for ClientHello { fn deserialize(input: &[u8]) -> anyhow::Result { - if input.len() <= MAXIMUM_MESSAGE_SIZE { - let message: Self = - bincode::deserialize(input).context("Couldn't deserialize client hello message")?; - if message.header == CLIENT_HELLO_HEADER { - Ok(message) - } else { - Err(anyhow!("Incorrect client hello message header")) - } - } else { - Err(anyhow!( - "Maximum handshake message size of {} exceeded, found {}", - MAXIMUM_MESSAGE_SIZE, + if input.len() != ClientHello::len() { + bail!( + "Invalid client hello message length: expected {}, found {}", + ClientHello::len(), input.len(), - )) + ); + } + let mut input = input; + if input.get_u8() != CLIENT_HELLO_HEADER { + bail!("Invalid client hello message header"); } + let mut random = [0u8; REPLAY_PROTECTION_ARRAY_LENGTH]; + input.copy_to_slice(&mut random); + Ok(Self { random }) } } @@ -220,7 +218,6 @@ impl ServerIdentity { additional_info: Vec, ) -> Self { Self { - header: SERVER_IDENTITY_HEADER, version: PROTOCOL_VERSION, ephemeral_public_key, random, @@ -238,33 +235,84 @@ impl ServerIdentity { pub fn set_transcript_signature(&mut self, transcript_signature: &[u8; SIGNATURE_LENGTH]) { self.transcript_signature = *transcript_signature; } + + const fn min_len() -> usize { + MESSAGE_HEADER_LENGTH + + PROTOCOL_VERSION_LENGTH + + KEY_AGREEMENT_ALGORITHM_KEY_LENGTH + + REPLAY_PROTECTION_ARRAY_LENGTH + + SIGNATURE_LENGTH + + SIGNING_ALGORITHM_KEY_LENGTH + + 2 * VEC_SIZE_PREFIX_LENGTH + } } impl Serializable for ServerIdentity { fn serialize(&self) -> anyhow::Result> { - bincode::serialize(&self).context("Couldn't serialize server identity message") + let mut result = Vec::with_capacity( + ServerIdentity::min_len() + self.attestation_info.len() + self.additional_info.len(), + ); + result.put_u8(SERVER_IDENTITY_HEADER); + result.put_u8(self.version); + result.put_slice(&self.ephemeral_public_key); + result.put_slice(&self.random); + result.put_slice(&self.transcript_signature); + result.put_slice(&self.signing_public_key); + put_vec(&mut result, &self.attestation_info); + put_vec(&mut result, &self.additional_info); + Ok(result) } } impl Deserializable for ServerIdentity { fn deserialize(input: &[u8]) -> anyhow::Result { - if input.len() <= MAXIMUM_MESSAGE_SIZE { - let message: Self = bincode::deserialize(input) - .context("Couldn't deserialize server identity message")?; - if message.header != SERVER_IDENTITY_HEADER { - return Err(anyhow!("Incorrect server identity message header")); - } - if message.version != PROTOCOL_VERSION { - return Err(anyhow!("Incorrect remote attestation protocol version")); - } - Ok(message) - } else { - Err(anyhow!( + if input.len() < ServerIdentity::min_len() { + bail!( + "Server identity message too short: expected at least {} bytes, found {}", + ServerIdentity::min_len(), + input.len(), + ); + } + if input.len() > MAXIMUM_MESSAGE_SIZE { + bail!( "Maximum handshake message size of {} exceeded, found {}", MAXIMUM_MESSAGE_SIZE, input.len(), - )) + ); } + let mut input = input; + if input.get_u8() != SERVER_IDENTITY_HEADER { + bail!("Invalid server identity message header"); + } + + let version = input.get_u8(); + let mut ephemeral_public_key = [0u8; KEY_AGREEMENT_ALGORITHM_KEY_LENGTH]; + input.copy_to_slice(&mut ephemeral_public_key); + let mut random = [0u8; REPLAY_PROTECTION_ARRAY_LENGTH]; + input.copy_to_slice(&mut random); + let mut transcript_signature = [0u8; SIGNATURE_LENGTH]; + input.copy_to_slice(&mut transcript_signature); + let mut signing_public_key = [0u8; SIGNING_ALGORITHM_KEY_LENGTH]; + input.copy_to_slice(&mut signing_public_key); + let attestation_info = get_vec(&mut input)?; + let additional_info = get_vec(&mut input)?; + + if input.has_remaining() { + bail!( + "Invalid server identity message: {} unused bytes detected", + input.remaining() + ); + } + + Ok(Self { + version, + ephemeral_public_key, + random, + transcript_signature, + signing_public_key, + attestation_info, + additional_info, + }) } } @@ -275,7 +323,6 @@ impl ClientIdentity { attestation_info: Vec, ) -> Self { Self { - header: CLIENT_IDENTITY_HEADER, ephemeral_public_key, transcript_signature: [Default::default(); SIGNATURE_LENGTH], signing_public_key, @@ -290,59 +337,120 @@ impl ClientIdentity { pub fn set_transcript_signature(&mut self, transcript_signature: &[u8; SIGNATURE_LENGTH]) { self.transcript_signature = *transcript_signature; } + + const fn min_len() -> usize { + MESSAGE_HEADER_LENGTH + + KEY_AGREEMENT_ALGORITHM_KEY_LENGTH + + SIGNATURE_LENGTH + + SIGNING_ALGORITHM_KEY_LENGTH + + VEC_SIZE_PREFIX_LENGTH + } } impl Serializable for ClientIdentity { fn serialize(&self) -> anyhow::Result> { - bincode::serialize(&self).context("Couldn't serialize client identity message") + let mut result = + Vec::with_capacity(ClientIdentity::min_len() + self.attestation_info.len()); + result.put_u8(CLIENT_IDENTITY_HEADER); + result.put_slice(&self.ephemeral_public_key); + result.put_slice(&self.transcript_signature); + result.put_slice(&self.signing_public_key); + put_vec(&mut result, &self.attestation_info); + Ok(result) } } impl Deserializable for ClientIdentity { fn deserialize(input: &[u8]) -> anyhow::Result { - if input.len() <= MAXIMUM_MESSAGE_SIZE { - let message: Self = bincode::deserialize(input) - .context("Couldn't deserialize client identity message")?; - if message.header == CLIENT_IDENTITY_HEADER { - Ok(message) - } else { - Err(anyhow!("Incorrect client identity message header")) - } - } else { - Err(anyhow!( + if input.len() < ClientIdentity::min_len() { + bail!( + "Client identity message too short: expected at least {} bytes, found {}", + ClientIdentity::min_len(), + input.len(), + ); + } + if input.len() > MAXIMUM_MESSAGE_SIZE { + bail!( "Maximum handshake message size of {} exceeded, found {}", MAXIMUM_MESSAGE_SIZE, input.len(), - )) + ); + } + let mut input = input; + if input.get_u8() != CLIENT_IDENTITY_HEADER { + bail!("Invalid client identity message header"); + } + + let mut ephemeral_public_key = [0u8; KEY_AGREEMENT_ALGORITHM_KEY_LENGTH]; + input.copy_to_slice(&mut ephemeral_public_key); + let mut transcript_signature = [0u8; SIGNATURE_LENGTH]; + input.copy_to_slice(&mut transcript_signature); + let mut signing_public_key = [0u8; SIGNING_ALGORITHM_KEY_LENGTH]; + input.copy_to_slice(&mut signing_public_key); + let attestation_info = get_vec(&mut input)?; + + if input.has_remaining() { + bail!( + "Invalid client identity message: {} unused bytes detected", + input.remaining() + ); } + + Ok(Self { + ephemeral_public_key, + transcript_signature, + signing_public_key, + attestation_info, + }) } } impl EncryptedData { pub fn new(nonce: [u8; NONCE_LENGTH], data: Vec) -> Self { - Self { - header: ENCRYPTED_DATA_HEADER, - nonce, - data, - } + Self { nonce, data } + } + + const fn min_len() -> usize { + MESSAGE_HEADER_LENGTH + NONCE_LENGTH + VEC_SIZE_PREFIX_LENGTH } } impl Serializable for EncryptedData { fn serialize(&self) -> anyhow::Result> { - bincode::serialize(&self).context("Couldn't serialize encrypted data message") + let mut result = Vec::with_capacity(EncryptedData::min_len() + self.data.len()); + result.put_u8(ENCRYPTED_DATA_HEADER); + result.put_slice(&self.nonce); + put_vec(&mut result, &self.data); + Ok(result) } } impl Deserializable for EncryptedData { - fn deserialize(bytes: &[u8]) -> anyhow::Result { - let message: Self = - bincode::deserialize(bytes).context("Couldn't deserialize encrypted data message")?; - if message.header == ENCRYPTED_DATA_HEADER { - Ok(message) - } else { - Err(anyhow!("Incorrect encrypted data message header")) + fn deserialize(input: &[u8]) -> anyhow::Result { + if input.len() < EncryptedData::min_len() { + bail!( + "Encrypted data message too short: expected at least {} bytes, found {}", + EncryptedData::min_len(), + input.len(), + ); + } + let mut input = input; + if input.get_u8() != ENCRYPTED_DATA_HEADER { + bail!("Invalid encrypted data message header"); + } + + let mut nonce = [0u8; NONCE_LENGTH]; + input.copy_to_slice(&mut nonce); + let data = get_vec(&mut input)?; + + if input.has_remaining() { + bail!( + "Invalid encrypted data message: {} unused bytes detected", + input.remaining() + ); } + + Ok(Self { nonce, data }) } } @@ -376,3 +484,20 @@ pub fn deserialize_message(input: &[u8]) -> anyhow::Result { header => Err(anyhow!("Unknown message header: {:#02x}", header)), } } + +fn put_vec(target: &mut Vec, source: &[u8]) { + target.put_u64_le(source.len() as u64); + target.put_slice(source); +} + +fn get_vec(source: &mut &[u8]) -> anyhow::Result> { + let length = source.get_u64_le(); + if length > source.remaining() as u64 { + bail!( + "Invalid vector serialization: required length is {} but only {} bytes provided", + length, + source.remaining() + ) + } + Ok(source.copy_to_bytes(length as usize).to_vec()) +} diff --git a/remote_attestation/rust/src/report.rs b/remote_attestation/rust/src/report.rs index f1269e6a8c2..b2080b5dbb7 100644 --- a/remote_attestation/rust/src/report.rs +++ b/remote_attestation/rust/src/report.rs @@ -23,7 +23,7 @@ impl AttestationReport { /// Placeholder function for collecting TEE measurement of remotely attested TEEs. pub fn new(data: &[u8]) -> Self { Self { - measurement: TEST_TEE_MEASUREMENT.to_string().as_bytes().to_vec(), + measurement: TEST_TEE_MEASUREMENT.as_bytes().to_vec(), data: data.to_vec(), ..Default::default() } diff --git a/remote_attestation/rust/src/tests/crypto.rs b/remote_attestation/rust/src/tests/crypto.rs index e2d58ec7662..75c26b1fa70 100644 --- a/remote_attestation/rust/src/tests/crypto.rs +++ b/remote_attestation/rust/src/tests/crypto.rs @@ -23,6 +23,7 @@ use crate::{ }, message::EncryptedData, }; +use alloc::vec::Vec; use quickcheck_macros::quickcheck; // Keys are only used for test purposes. diff --git a/remote_attestation/rust/src/tests/handshaker.rs b/remote_attestation/rust/src/tests/handshaker.rs index 26161fa4718..2a82e749473 100644 --- a/remote_attestation/rust/src/tests/handshaker.rs +++ b/remote_attestation/rust/src/tests/handshaker.rs @@ -22,6 +22,7 @@ use crate::{ }, tests::message::INVALID_MESSAGE_HEADER, }; +use alloc::{boxed::Box, vec}; use assert_matches::assert_matches; const TEE_MEASUREMENT: &str = "Test TEE measurement"; diff --git a/remote_attestation/rust/src/tests/message.rs b/remote_attestation/rust/src/tests/message.rs index 566ae51bc98..c962aa833f6 100644 --- a/remote_attestation/rust/src/tests/message.rs +++ b/remote_attestation/rust/src/tests/message.rs @@ -25,6 +25,7 @@ use crate::{ MAXIMUM_MESSAGE_SIZE, REPLAY_PROTECTION_ARRAY_LENGTH, SERVER_IDENTITY_HEADER, }, }; +use alloc::{vec, vec::Vec}; use anyhow::{anyhow, Context}; use assert_matches::assert_matches; use quickcheck::{quickcheck, TestResult}; @@ -35,7 +36,7 @@ const INVALID_PROTOCOL_VERSION: u8 = 2; /// Creates a zero initialized array. fn default_array() -> [T; L] where - T: std::marker::Copy + std::default::Default, + T: core::marker::Copy + core::default::Default, { [Default::default(); L] } @@ -43,11 +44,11 @@ where /// Converts slices to arrays (expands with zeroes). fn to_array(input: &[T]) -> anyhow::Result<[T; L]> where - T: std::marker::Copy + std::default::Default, + T: core::marker::Copy + core::default::Default, { if input.len() <= L { // `Default` is only implemented for a limited number of array sizes. - // https://doc.rust-lang.org/std/primitive.array.html#impl-Default + // https://doc.rust-lang.org/core/primitive.array.html#impl-Default let mut result: [T; L] = default_array(); result[..input.len()].copy_from_slice(&input[..input.len()]); Ok(result)