diff --git a/src/vdaf.rs b/src/vdaf.rs index cc2a306fe..5fb568514 100644 --- a/src/vdaf.rs +++ b/src/vdaf.rs @@ -18,6 +18,7 @@ use crate::{ }; use serde::{Deserialize, Serialize}; use std::{fmt::Debug, io::Cursor}; +use subtle::{Choice, ConstantTimeEq}; /// A component of the domain-separation tag, used to bind the VDAF operations to the document /// version. This will be revised with each draft with breaking changes. @@ -57,7 +58,7 @@ pub enum VdafError { } /// An additive share of a vector of field elements. -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Clone, Debug)] pub enum Share { /// An uncompressed share, typically sent to the leader. Leader(Vec), @@ -78,6 +79,26 @@ impl Share { } } +impl PartialEq for Share { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl Eq for Share {} + +impl ConstantTimeEq for Share { + fn ct_eq(&self, other: &Self) -> subtle::Choice { + // We allow short-circuiting on the type (Leader vs Helper) of the value, but not the types' + // contents. + match (self, other) { + (Share::Leader(self_val), Share::Leader(other_val)) => self_val.ct_eq(other_val), + (Share::Helper(self_val), Share::Helper(other_val)) => self_val.ct_eq(other_val), + _ => Choice::from(0), + } + } +} + /// Parameters needed to decode a [`Share`] #[derive(Clone, Debug, PartialEq, Eq)] pub(crate) enum ShareDecodingParameter { @@ -310,9 +331,23 @@ pub trait Aggregatable: Clone + Debug + From { } /// An output share comprised of a vector of field elements. -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug)] pub struct OutputShare(Vec); +impl PartialEq for OutputShare { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl Eq for OutputShare {} + +impl ConstantTimeEq for OutputShare { + fn ct_eq(&self, other: &Self) -> Choice { + self.0.ct_eq(&other.0) + } +} + impl AsRef<[F]> for OutputShare { fn as_ref(&self) -> &[F] { &self.0 @@ -339,9 +374,24 @@ impl Encode for OutputShare { /// /// This is suitable for VDAFs where both output shares and aggregate shares are vectors of field /// elements, and output shares need no special transformation to be merged into an aggregate share. -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] + pub struct AggregateShare(Vec); +impl PartialEq for AggregateShare { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl Eq for AggregateShare {} + +impl ConstantTimeEq for AggregateShare { + fn ct_eq(&self, other: &Self) -> subtle::Choice { + self.0.ct_eq(&other.0) + } +} + impl AsRef<[F]> for AggregateShare { fn as_ref(&self) -> &[F] { &self.0 @@ -552,6 +602,65 @@ where assert_eq!(encoded, bytes); } +#[cfg(test)] +fn equality_comparison_test(values: &[T]) +where + T: Debug + PartialEq, +{ + use std::ptr; + + // This function expects that every value passed in `values` is distinct, i.e. should not + // compare as equal to any other element. We test both (i, j) and (j, i) to gain confidence that + // equality implementations are symmetric. + for (i, i_val) in values.iter().enumerate() { + for (j, j_val) in values.iter().enumerate() { + if i == j { + assert!(ptr::eq(i_val, j_val)); // sanity + assert_eq!( + i_val, j_val, + "Expected element at index {i} to be equal to itself, but it was not" + ); + } else { + assert_ne!( + i_val, j_val, + "Expected elements at indices {i} & {j} to not be equal, but they were" + ) + } + } + } +} + +#[cfg(test)] +mod tests { + use crate::vdaf::{equality_comparison_test, xof::Seed, AggregateShare, OutputShare, Share}; + + #[test] + fn share_equality_test() { + equality_comparison_test(&[ + Share::Leader(Vec::from([1, 2, 3])), + Share::Leader(Vec::from([3, 2, 1])), + Share::Helper(Seed([1, 2, 3])), + Share::Helper(Seed([3, 2, 1])), + ]) + } + + #[test] + fn output_share_equality_test() { + equality_comparison_test(&[ + OutputShare(Vec::from([1, 2, 3])), + OutputShare(Vec::from([3, 2, 1])), + ]) + } + + #[test] + fn aggregate_share_equality_test() { + equality_comparison_test(&[ + AggregateShare(Vec::from([1, 2, 3])), + AggregateShare(Vec::from([3, 2, 1])), + ]) + } +} + #[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] #[cfg_attr( docsrs, diff --git a/src/vdaf/poplar1.rs b/src/vdaf/poplar1.rs index 7f9a277d6..7b09e5f53 100644 --- a/src/vdaf/poplar1.rs +++ b/src/vdaf/poplar1.rs @@ -110,7 +110,7 @@ impl ParameterizedDecode> for P /// /// This is comprised of an IDPF key share and the correlated randomness used to compute the sketch /// during preparation. -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone)] pub struct Poplar1InputShare { /// IDPF key share. idpf_key: Seed<16>, @@ -128,6 +128,32 @@ pub struct Poplar1InputShare { corr_leaf: [Field255; 2], } +impl PartialEq for Poplar1InputShare { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl Eq for Poplar1InputShare {} + +impl ConstantTimeEq for Poplar1InputShare { + fn ct_eq(&self, other: &Self) -> Choice { + // We short-circuit on the length of corr_inner being different. Only the content is + // protected. + if self.corr_inner.len() != other.corr_inner.len() { + return Choice::from(0); + } + + let mut res = self.idpf_key.ct_eq(&other.idpf_key) + & self.corr_seed.ct_eq(&other.corr_seed) + & self.corr_leaf.ct_eq(&other.corr_leaf); + for (x, y) in self.corr_inner.iter().zip(other.corr_inner.iter()) { + res &= x.ct_eq(y); + } + res + } +} + impl Encode for Poplar1InputShare { fn encode(&self, bytes: &mut Vec) { self.idpf_key.encode(bytes); @@ -174,9 +200,23 @@ impl<'a, P, const SEED_SIZE: usize> ParameterizedDecode<(&'a Poplar1 bool { + self.ct_eq(other).into() + } +} + +impl Eq for Poplar1PrepareState {} + +impl ConstantTimeEq for Poplar1PrepareState { + fn ct_eq(&self, other: &Self) -> Choice { + self.0.ct_eq(&other.0) + } +} + impl Encode for Poplar1PrepareState { fn encode(&self, bytes: &mut Vec) { self.0.encode(bytes) @@ -201,12 +241,31 @@ impl<'a, P, const SEED_SIZE: usize> ParameterizedDecode<(&'a Poplar1), Leaf(PrepareState), } +impl PartialEq for PrepareStateVariant { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl Eq for PrepareStateVariant {} + +impl ConstantTimeEq for PrepareStateVariant { + fn ct_eq(&self, other: &Self) -> Choice { + // We allow short-circuiting on the type (Inner vs Leaf). + match (self, other) { + (Self::Inner(self_val), Self::Inner(other_val)) => self_val.ct_eq(other_val), + (Self::Leaf(self_val), Self::Leaf(other_val)) => self_val.ct_eq(other_val), + _ => Choice::from(0), + } + } +} + impl Encode for PrepareStateVariant { fn encode(&self, bytes: &mut Vec) { match self { @@ -252,12 +311,26 @@ impl<'a, P, const SEED_SIZE: usize> ParameterizedDecode<(&'a Poplar1 { sketch: SketchState, output_share: Vec, } +impl PartialEq for PrepareState { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl Eq for PrepareState {} + +impl ConstantTimeEq for PrepareState { + fn ct_eq(&self, other: &Self) -> Choice { + self.sketch.ct_eq(&other.sketch) & self.output_share.ct_eq(&other.output_share) + } +} + impl Encode for PrepareState { fn encode(&self, bytes: &mut Vec) { self.sketch.encode(bytes); @@ -297,7 +370,7 @@ impl<'a, P, F: FieldElement, const SEED_SIZE: usize> } } -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug)] enum SketchState { #[allow(non_snake_case)] RoundOne { @@ -308,6 +381,44 @@ enum SketchState { RoundTwo, } +impl PartialEq for SketchState { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl Eq for SketchState {} + +impl ConstantTimeEq for SketchState { + fn ct_eq(&self, other: &Self) -> Choice { + // We allow short-circuiting on the round (RoundOne vs RoundTwo), as well as is_leader for + // RoundOne comparisons. + match (self, other) { + ( + SketchState::RoundOne { + A_share: self_a_share, + B_share: self_b_share, + is_leader: self_is_leader, + }, + SketchState::RoundOne { + A_share: other_a_share, + B_share: other_b_share, + is_leader: other_is_leader, + }, + ) => { + if self_is_leader != other_is_leader { + return Choice::from(0); + } + + self_a_share.ct_eq(other_a_share) & self_b_share.ct_eq(other_b_share) + } + + (SketchState::RoundTwo, SketchState::RoundTwo) => Choice::from(1), + _ => Choice::from(0), + } + } +} + impl Encode for SketchState { fn encode(&self, bytes: &mut Vec) { match self { @@ -450,7 +561,7 @@ impl ParameterizedDecode for Poplar1PrepareMessage { } /// A vector of field elements transmitted while evaluating Poplar1. -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug)] pub enum Poplar1FieldVec { /// Field type for inner nodes of the IDPF tree. Inner(Vec), @@ -469,6 +580,29 @@ impl Poplar1FieldVec { } } +impl PartialEq for Poplar1FieldVec { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl Eq for Poplar1FieldVec {} + +impl ConstantTimeEq for Poplar1FieldVec { + fn ct_eq(&self, other: &Self) -> Choice { + // We allow short-circuiting on the type (Inner vs Leaf). + match (self, other) { + (Poplar1FieldVec::Inner(self_val), Poplar1FieldVec::Inner(other_val)) => { + self_val.ct_eq(other_val) + } + (Poplar1FieldVec::Leaf(self_val), Poplar1FieldVec::Leaf(other_val)) => { + self_val.ct_eq(other_val) + } + _ => Choice::from(0), + } + } +} + impl Encode for Poplar1FieldVec { fn encode(&self, bytes: &mut Vec) { match self { @@ -1368,7 +1502,7 @@ where #[cfg(test)] mod tests { use super::*; - use crate::vdaf::run_vdaf_prepare; + use crate::vdaf::{equality_comparison_test, run_vdaf_prepare}; use assert_matches::assert_matches; use rand::prelude::*; use serde::Deserialize; @@ -2129,4 +2263,196 @@ mod tests { fn test_vec_poplar1_3() { check_test_vec(include_str!("test_vec/07/Poplar1_3.json")); } + + #[test] + fn input_share_equality_test() { + equality_comparison_test(&[ + // Default. + Poplar1InputShare { + idpf_key: Seed([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]), + corr_seed: Seed([16, 17, 18]), + corr_inner: Vec::from([ + [Field64::from(19), Field64::from(20)], + [Field64::from(21), Field64::from(22)], + [Field64::from(23), Field64::from(24)], + ]), + corr_leaf: [Field255::from(25), Field255::from(26)], + }, + // Modified idpf_key. + Poplar1InputShare { + idpf_key: Seed([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0]), + corr_seed: Seed([16, 17, 18]), + corr_inner: Vec::from([ + [Field64::from(19), Field64::from(20)], + [Field64::from(21), Field64::from(22)], + [Field64::from(23), Field64::from(24)], + ]), + corr_leaf: [Field255::from(25), Field255::from(26)], + }, + // Modified corr_seed. + Poplar1InputShare { + idpf_key: Seed([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]), + corr_seed: Seed([18, 17, 16]), + corr_inner: Vec::from([ + [Field64::from(19), Field64::from(20)], + [Field64::from(21), Field64::from(22)], + [Field64::from(23), Field64::from(24)], + ]), + corr_leaf: [Field255::from(25), Field255::from(26)], + }, + // Modified corr_inner. + Poplar1InputShare { + idpf_key: Seed([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]), + corr_seed: Seed([16, 17, 18]), + corr_inner: Vec::from([ + [Field64::from(24), Field64::from(23)], + [Field64::from(22), Field64::from(21)], + [Field64::from(20), Field64::from(19)], + ]), + corr_leaf: [Field255::from(25), Field255::from(26)], + }, + // Modified corr_leaf. + Poplar1InputShare { + idpf_key: Seed([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]), + corr_seed: Seed([16, 17, 18]), + corr_inner: Vec::from([ + [Field64::from(19), Field64::from(20)], + [Field64::from(21), Field64::from(22)], + [Field64::from(23), Field64::from(24)], + ]), + corr_leaf: [Field255::from(26), Field255::from(25)], + }, + ]) + } + + #[test] + fn prepare_state_equality_test() { + // This test effectively covers PrepareStateVariant, PrepareState, SketchState as well. + equality_comparison_test(&[ + // Inner, round one. (default) + Poplar1PrepareState(PrepareStateVariant::Inner(PrepareState { + sketch: SketchState::RoundOne { + A_share: Field64::from(0), + B_share: Field64::from(1), + is_leader: false, + }, + output_share: Vec::from([Field64::from(2), Field64::from(3)]), + })), + // Inner, round one, modified A_share. + Poplar1PrepareState(PrepareStateVariant::Inner(PrepareState { + sketch: SketchState::RoundOne { + A_share: Field64::from(100), + B_share: Field64::from(1), + is_leader: false, + }, + output_share: Vec::from([Field64::from(2), Field64::from(3)]), + })), + // Inner, round one, modified B_share. + Poplar1PrepareState(PrepareStateVariant::Inner(PrepareState { + sketch: SketchState::RoundOne { + A_share: Field64::from(0), + B_share: Field64::from(101), + is_leader: false, + }, + output_share: Vec::from([Field64::from(2), Field64::from(3)]), + })), + // Inner, round one, modified is_leader. + Poplar1PrepareState(PrepareStateVariant::Inner(PrepareState { + sketch: SketchState::RoundOne { + A_share: Field64::from(0), + B_share: Field64::from(1), + is_leader: true, + }, + output_share: Vec::from([Field64::from(2), Field64::from(3)]), + })), + // Inner, round one, modified output_share. + Poplar1PrepareState(PrepareStateVariant::Inner(PrepareState { + sketch: SketchState::RoundOne { + A_share: Field64::from(0), + B_share: Field64::from(1), + is_leader: false, + }, + output_share: Vec::from([Field64::from(3), Field64::from(2)]), + })), + // Inner, round two. (default) + Poplar1PrepareState(PrepareStateVariant::Inner(PrepareState { + sketch: SketchState::RoundTwo, + output_share: Vec::from([Field64::from(2), Field64::from(3)]), + })), + // Inner, round two, modified output_share. + Poplar1PrepareState(PrepareStateVariant::Inner(PrepareState { + sketch: SketchState::RoundTwo, + output_share: Vec::from([Field64::from(3), Field64::from(2)]), + })), + // Leaf, round one. (default) + Poplar1PrepareState(PrepareStateVariant::Leaf(PrepareState { + sketch: SketchState::RoundOne { + A_share: Field255::from(0), + B_share: Field255::from(1), + is_leader: false, + }, + output_share: Vec::from([Field255::from(2), Field255::from(3)]), + })), + // Leaf, round one, modified A_share. + Poplar1PrepareState(PrepareStateVariant::Leaf(PrepareState { + sketch: SketchState::RoundOne { + A_share: Field255::from(100), + B_share: Field255::from(1), + is_leader: false, + }, + output_share: Vec::from([Field255::from(2), Field255::from(3)]), + })), + // Leaf, round one, modified B_share. + Poplar1PrepareState(PrepareStateVariant::Leaf(PrepareState { + sketch: SketchState::RoundOne { + A_share: Field255::from(0), + B_share: Field255::from(101), + is_leader: false, + }, + output_share: Vec::from([Field255::from(2), Field255::from(3)]), + })), + // Leaf, round one, modified is_leader. + Poplar1PrepareState(PrepareStateVariant::Leaf(PrepareState { + sketch: SketchState::RoundOne { + A_share: Field255::from(0), + B_share: Field255::from(1), + is_leader: true, + }, + output_share: Vec::from([Field255::from(2), Field255::from(3)]), + })), + // Leaf, round one, modified output_share. + Poplar1PrepareState(PrepareStateVariant::Leaf(PrepareState { + sketch: SketchState::RoundOne { + A_share: Field255::from(0), + B_share: Field255::from(1), + is_leader: false, + }, + output_share: Vec::from([Field255::from(3), Field255::from(2)]), + })), + // Leaf, round two. (default) + Poplar1PrepareState(PrepareStateVariant::Leaf(PrepareState { + sketch: SketchState::RoundTwo, + output_share: Vec::from([Field255::from(2), Field255::from(3)]), + })), + // Leaf, round two, modified output_share. + Poplar1PrepareState(PrepareStateVariant::Leaf(PrepareState { + sketch: SketchState::RoundTwo, + output_share: Vec::from([Field255::from(3), Field255::from(2)]), + })), + ]) + } + + #[test] + fn field_vec_equality_test() { + equality_comparison_test(&[ + // Inner. (default) + Poplar1FieldVec::Inner(Vec::from([Field64::from(0), Field64::from(1)])), + // Inner, modified value. + Poplar1FieldVec::Inner(Vec::from([Field64::from(1), Field64::from(0)])), + // Leaf. (deafult) + Poplar1FieldVec::Leaf(Vec::from([Field255::from(0), Field255::from(1)])), + // Leaf, modified value. + Poplar1FieldVec::Leaf(Vec::from([Field255::from(1), Field255::from(0)])), + ]) + } } diff --git a/src/vdaf/prio2.rs b/src/vdaf/prio2.rs index ff61cf52d..778d87b36 100644 --- a/src/vdaf/prio2.rs +++ b/src/vdaf/prio2.rs @@ -22,6 +22,7 @@ use crate::{ use hmac::{Hmac, Mac}; use sha2::Sha256; use std::{convert::TryFrom, io::Cursor}; +use subtle::{Choice, ConstantTimeEq}; mod client; mod server; @@ -165,9 +166,23 @@ impl Client<16> for Prio2 { } /// State of each [`Aggregator`] during the Preparation phase. -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Clone, Debug)] pub struct Prio2PrepareState(Share); +impl PartialEq for Prio2PrepareState { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl Eq for Prio2PrepareState {} + +impl ConstantTimeEq for Prio2PrepareState { + fn ct_eq(&self, other: &Self) -> Choice { + self.0.ct_eq(&other.0) + } +} + impl Encode for Prio2PrepareState { fn encode(&self, bytes: &mut Vec) { self.0.encode(bytes); @@ -370,7 +385,10 @@ fn role_try_from(agg_id: usize) -> Result { #[cfg(test)] mod tests { use super::*; - use crate::vdaf::{fieldvec_roundtrip_test, prio2::test_vector::Priov2TestVector, run_vdaf}; + use crate::vdaf::{ + equality_comparison_test, fieldvec_roundtrip_test, prio2::test_vector::Priov2TestVector, + run_vdaf, + }; use assert_matches::assert_matches; use rand::prelude::*; @@ -501,4 +519,24 @@ mod tests { assert_eq!(reconstructed, test_vector.reference_sum); } + + #[test] + fn prepare_state_equality_test() { + equality_comparison_test(&[ + Prio2PrepareState(Share::Leader(Vec::from([ + FieldPrio2::from(0), + FieldPrio2::from(1), + ]))), + Prio2PrepareState(Share::Leader(Vec::from([ + FieldPrio2::from(1), + FieldPrio2::from(0), + ]))), + Prio2PrepareState(Share::Helper(Seed( + (0..32).collect::>().try_into().unwrap(), + ))), + Prio2PrepareState(Share::Helper(Seed( + (1..33).collect::>().try_into().unwrap(), + ))), + ]) + } } diff --git a/src/vdaf/prio3.rs b/src/vdaf/prio3.rs index 851689e36..7d0b107e2 100644 --- a/src/vdaf/prio3.rs +++ b/src/vdaf/prio3.rs @@ -61,6 +61,7 @@ use std::fmt::Debug; use std::io::Cursor; use std::iter::{self, IntoIterator}; use std::marker::PhantomData; +use subtle::{Choice, ConstantTimeEq}; const DST_MEASUREMENT_SHARE: u16 = 1; const DST_PROOF_SHARE: u16 = 2; @@ -595,7 +596,7 @@ where } /// Message broadcast by the [`Client`] to every [`Aggregator`] during the Sharding phase. -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug)] pub struct Prio3PublicShare { /// Contributions to the joint randomness from every aggregator's share. joint_rand_parts: Option>>, @@ -620,6 +621,24 @@ impl Encode for Prio3PublicShare { } } +impl PartialEq for Prio3PublicShare { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl Eq for Prio3PublicShare {} + +impl ConstantTimeEq for Prio3PublicShare { + fn ct_eq(&self, other: &Self) -> Choice { + // We allow short-circuiting on the presence or absence of the joint_rand_parts. + option_ct_eq( + self.joint_rand_parts.as_deref(), + other.joint_rand_parts.as_deref(), + ) + } +} + impl ParameterizedDecode> for Prio3PublicShare where @@ -646,7 +665,7 @@ where } /// Message sent by the [`Client`] to each [`Aggregator`] during the Sharding phase. -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Clone, Debug)] pub struct Prio3InputShare { /// The measurement share. measurement_share: Share, @@ -659,6 +678,25 @@ pub struct Prio3InputShare { joint_rand_blind: Option>, } +impl PartialEq for Prio3InputShare { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl Eq for Prio3InputShare {} + +impl ConstantTimeEq for Prio3InputShare { + fn ct_eq(&self, other: &Self) -> Choice { + // We allow short-circuiting on the presence or absence of the joint_rand_blind. + option_ct_eq( + self.joint_rand_blind.as_ref(), + other.joint_rand_blind.as_ref(), + ) & self.measurement_share.ct_eq(&other.measurement_share) + & self.proof_share.ct_eq(&other.proof_share) + } +} + impl Encode for Prio3InputShare { fn encode(&self, bytes: &mut Vec) { if matches!( @@ -726,7 +764,7 @@ where } } -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Clone, Debug)] /// Message broadcast by each [`Aggregator`] in each round of the Preparation phase. pub struct Prio3PrepareShare { /// A share of the FLP verifier message. (See [`Type`].) @@ -736,6 +774,24 @@ pub struct Prio3PrepareShare { joint_rand_part: Option>, } +impl PartialEq for Prio3PrepareShare { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl Eq for Prio3PrepareShare {} + +impl ConstantTimeEq for Prio3PrepareShare { + fn ct_eq(&self, other: &Self) -> Choice { + // We allow short-circuiting on the presence or absence of the joint_rand_part. + option_ct_eq( + self.joint_rand_part.as_ref(), + other.joint_rand_part.as_ref(), + ) & self.verifier.ct_eq(&other.verifier) + } +} + impl Encode for Prio3PrepareShare { @@ -783,13 +839,31 @@ impl } } -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Clone, Debug)] /// Result of combining a round of [`Prio3PrepareShare`] messages. pub struct Prio3PrepareMessage { /// The joint randomness seed computed by the Aggregators. joint_rand_seed: Option>, } +impl PartialEq for Prio3PrepareMessage { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl Eq for Prio3PrepareMessage {} + +impl ConstantTimeEq for Prio3PrepareMessage { + fn ct_eq(&self, other: &Self) -> Choice { + // We allow short-circuiting on the presnce or absence of the joint_rand_seed. + option_ct_eq( + self.joint_rand_seed.as_ref(), + other.joint_rand_seed.as_ref(), + ) + } +} + impl Encode for Prio3PrepareMessage { fn encode(&self, bytes: &mut Vec) { if let Some(ref seed) = self.joint_rand_seed { @@ -841,7 +915,7 @@ where } /// State of each [`Aggregator`] during the Preparation phase. -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Clone, Debug)] pub struct Prio3PrepareState { measurement_share: Share, joint_rand_seed: Option>, @@ -849,6 +923,29 @@ pub struct Prio3PrepareState { verifier_len: usize, } +impl PartialEq for Prio3PrepareState { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl Eq for Prio3PrepareState {} + +impl ConstantTimeEq for Prio3PrepareState { + fn ct_eq(&self, other: &Self) -> Choice { + // We allow short-circuiting on the presence or absence of the joint_rand_seed, as well as + // the aggregator ID & verifier length parameters. + if self.agg_id != other.agg_id || self.verifier_len != other.verifier_len { + return Choice::from(0); + } + + option_ct_eq( + self.joint_rand_seed.as_ref(), + other.joint_rand_seed.as_ref(), + ) & self.measurement_share.ct_eq(&other.measurement_share) + } +} + impl Encode for Prio3PrepareState { @@ -1111,7 +1208,13 @@ where ) -> Result, VdafError> { if self.typ.joint_rand_len() > 0 { // Check that the joint randomness was correct. - if step.joint_rand_seed.as_ref().unwrap() != msg.joint_rand_seed.as_ref().unwrap() { + if step + .joint_rand_seed + .as_ref() + .unwrap() + .ct_ne(msg.joint_rand_seed.as_ref().unwrap()) + .into() + { return Err(VdafError::Uncategorized( "joint randomness mismatch".to_string(), )); @@ -1259,12 +1362,29 @@ where } } +// This function determines equality between two optional, constant-time comparable values. It +// short-circuits on the existence (but not contents) of the values -- a timing side-channel may +// reveal whether the values match on Some or None. +#[inline] +fn option_ct_eq(left: Option<&T>, right: Option<&T>) -> Choice +where + T: ConstantTimeEq + ?Sized, +{ + match (left, right) { + (Some(left), Some(right)) => left.ct_eq(right), + (None, None) => Choice::from(1), + _ => Choice::from(0), + } +} + #[cfg(test)] mod tests { use super::*; #[cfg(feature = "experimental")] use crate::flp::gadgets::ParallelSumGadget; - use crate::vdaf::{fieldvec_roundtrip_test, run_vdaf, run_vdaf_prepare}; + use crate::vdaf::{ + equality_comparison_test, fieldvec_roundtrip_test, run_vdaf, run_vdaf_prepare, + }; use assert_matches::assert_matches; #[cfg(feature = "experimental")] use fixed::{ @@ -1732,4 +1852,147 @@ mod tests { 12, ); } + + #[test] + fn public_share_equality_test() { + equality_comparison_test(&[ + Prio3PublicShare { + joint_rand_parts: Some(Vec::from([Seed([0])])), + }, + Prio3PublicShare { + joint_rand_parts: Some(Vec::from([Seed([1])])), + }, + Prio3PublicShare { + joint_rand_parts: None, + }, + ]) + } + + #[test] + fn input_share_equality_test() { + equality_comparison_test(&[ + // Default. + Prio3InputShare { + measurement_share: Share::Leader(Vec::from([0])), + proof_share: Share::Leader(Vec::from([1])), + joint_rand_blind: Some(Seed([2])), + }, + // Modified measurement share. + Prio3InputShare { + measurement_share: Share::Leader(Vec::from([100])), + proof_share: Share::Leader(Vec::from([1])), + joint_rand_blind: Some(Seed([2])), + }, + // Modified proof share. + Prio3InputShare { + measurement_share: Share::Leader(Vec::from([0])), + proof_share: Share::Leader(Vec::from([101])), + joint_rand_blind: Some(Seed([2])), + }, + // Modified joint_rand_blind. + Prio3InputShare { + measurement_share: Share::Leader(Vec::from([0])), + proof_share: Share::Leader(Vec::from([1])), + joint_rand_blind: Some(Seed([102])), + }, + // Missing joint_rand_blind. + Prio3InputShare { + measurement_share: Share::Leader(Vec::from([0])), + proof_share: Share::Leader(Vec::from([1])), + joint_rand_blind: None, + }, + ]) + } + + #[test] + fn prepare_share_equality_test() { + equality_comparison_test(&[ + // Default. + Prio3PrepareShare { + verifier: Vec::from([0]), + joint_rand_part: Some(Seed([1])), + }, + // Modified verifier. + Prio3PrepareShare { + verifier: Vec::from([100]), + joint_rand_part: Some(Seed([1])), + }, + // Modified joint_rand_part. + Prio3PrepareShare { + verifier: Vec::from([0]), + joint_rand_part: Some(Seed([101])), + }, + // Missing joint_rand_part. + Prio3PrepareShare { + verifier: Vec::from([0]), + joint_rand_part: None, + }, + ]) + } + + #[test] + fn prepare_message_equality_test() { + equality_comparison_test(&[ + // Default. + Prio3PrepareMessage { + joint_rand_seed: Some(Seed([0])), + }, + // Modified joint_rand_seed. + Prio3PrepareMessage { + joint_rand_seed: Some(Seed([100])), + }, + // Missing joint_rand_seed. + Prio3PrepareMessage { + joint_rand_seed: None, + }, + ]) + } + + #[test] + fn prepare_state_equality_test() { + equality_comparison_test(&[ + // Default. + Prio3PrepareState { + measurement_share: Share::Leader(Vec::from([0])), + joint_rand_seed: Some(Seed([1])), + agg_id: 2, + verifier_len: 3, + }, + // Modified measurement share. + Prio3PrepareState { + measurement_share: Share::Leader(Vec::from([100])), + joint_rand_seed: Some(Seed([1])), + agg_id: 2, + verifier_len: 3, + }, + // Modified joint_rand_seed. + Prio3PrepareState { + measurement_share: Share::Leader(Vec::from([0])), + joint_rand_seed: Some(Seed([101])), + agg_id: 2, + verifier_len: 3, + }, + // Missing joint_rand_seed. + Prio3PrepareState { + measurement_share: Share::Leader(Vec::from([0])), + joint_rand_seed: None, + agg_id: 2, + verifier_len: 3, + }, + // Modified agg_id. + Prio3PrepareState { + measurement_share: Share::Leader(Vec::from([0])), + joint_rand_seed: Some(Seed([1])), + agg_id: 102, + verifier_len: 3, + }, + // Modified verifier_len. + Prio3PrepareState { + measurement_share: Share::Leader(Vec::from([0])), + joint_rand_seed: Some(Seed([1])), + agg_id: 2, + verifier_len: 103, + }, + ]) + } } diff --git a/src/vdaf/xof.rs b/src/vdaf/xof.rs index 64af49a01..1456c5588 100644 --- a/src/vdaf/xof.rs +++ b/src/vdaf/xof.rs @@ -38,7 +38,7 @@ use std::{ use subtle::{Choice, ConstantTimeEq}; /// Input of [`Xof`]. -#[derive(Clone, Debug, Eq)] +#[derive(Clone, Debug)] pub struct Seed(pub(crate) [u8; SEED_SIZE]); impl Seed { @@ -61,18 +61,20 @@ impl AsRef<[u8; SEED_SIZE]> for Seed { } } -impl ConstantTimeEq for Seed { - fn ct_eq(&self, other: &Self) -> Choice { - self.0.ct_eq(&other.0) - } -} - impl PartialEq for Seed { fn eq(&self, other: &Self) -> bool { self.ct_eq(other).into() } } +impl Eq for Seed {} + +impl ConstantTimeEq for Seed { + fn ct_eq(&self, other: &Self) -> Choice { + self.0.ct_eq(&other.0) + } +} + impl Encode for Seed { fn encode(&self, bytes: &mut Vec) { bytes.extend_from_slice(&self.0[..]); @@ -405,7 +407,7 @@ impl SeedStreamFixedKeyAes128 { #[cfg(test)] mod tests { use super::*; - use crate::field::Field128; + use crate::{field::Field128, vdaf::equality_comparison_test}; use serde::{Deserialize, Serialize}; use std::{convert::TryInto, io::Cursor}; @@ -535,4 +537,9 @@ mod tests { assert_eq!(output_1_trait_api, output_1_alternate_api); assert_eq!(output_2_trait_api, output_2_alternate_api); } + + #[test] + fn seed_equality_test() { + equality_comparison_test(&[Seed([1, 2, 3]), Seed([3, 2, 1])]) + } }