diff --git a/air/Cargo.toml b/air/Cargo.toml index 4d1b0641f..e461ecf57 100644 --- a/air/Cargo.toml +++ b/air/Cargo.toml @@ -25,6 +25,7 @@ fri = { version = "0.9", path = "../fri", package = "winter-fri", default-featur libm = "0.2.8" math = { version = "0.9", path = "../math", package = "winter-math", default-features = false } utils = { version = "0.9", path = "../utils/core", package = "winter-utils", default-features = false } +libc-print = "0.1.23" [dev-dependencies] rand-utils = { version = "0.9", path = "../utils/rand", package = "winter-rand-utils" } diff --git a/air/src/air/logup_gkr/lagrange/transition.rs b/air/src/air/logup_gkr/lagrange/transition.rs index 5f5b110e6..0b7e72344 100644 --- a/air/src/air/logup_gkr/lagrange/transition.rs +++ b/air/src/air/logup_gkr/lagrange/transition.rs @@ -62,6 +62,9 @@ impl LagrangeKernelTransitionConstraints { let c = lagrange_kernel_column_frame; let v = c.num_rows() - 1; let r = lagrange_kernel_rand_elements; + // TODO: avoid reverse() + let mut r = r.to_vec(); + r.reverse(); let k = constraint_idx + 1; (r[v - k] * c[0]) - ((E::ONE - r[v - k]) * c[v - k + 1]) @@ -130,6 +133,9 @@ impl LagrangeKernelTransitionConstraints { let c = lagrange_kernel_column_frame; let v = c.num_rows() - 1; let r = lagrange_kernel_rand_elements; + // TODO: avoid reverse() + let mut r = r.to_vec(); + r.reverse(); for k in 1..v + 1 { transition_evals[k - 1] = (r[v - k] * c[0]) - ((E::ONE - r[v - k]) * c[v - k + 1]); diff --git a/air/src/air/logup_gkr/mod.rs b/air/src/air/logup_gkr/mod.rs index d3e198912..da0e83124 100644 --- a/air/src/air/logup_gkr/mod.rs +++ b/air/src/air/logup_gkr/mod.rs @@ -109,7 +109,7 @@ pub trait LogUpGkrEvaluator: Clone + Sync { /// [1]: https://eprint.iacr.org/2023/1284 fn generate_univariate_iop_for_multi_linear_opening_data( &self, - openings: Vec, + openings: Vec>, eval_point: Vec, public_coin: &mut impl RandomCoin, ) -> GkrData @@ -117,17 +117,25 @@ pub trait LogUpGkrEvaluator: Clone + Sync { E: FieldElement, H: ElementHasher, { + let openings: Vec = + openings[0].clone().chunks(2).flat_map(|ops| [ops[0], ops[1]]).collect(); public_coin.reseed(H::hash_elements(&openings)); + let folding_randomness: E = public_coin.draw().expect("failed to generate randomness"); + let batched_openings: Vec = + openings.chunks(2).map(|p| p[0] + folding_randomness * (p[1] - p[0])).collect(); + let mut batching_randomness = Vec::with_capacity(openings.len() - 1); for _ in 0..openings.len() - 1 { batching_randomness.push(public_coin.draw().expect("failed to generate randomness")) } + let mut eval_point = eval_point; + eval_point.push(folding_randomness); GkrData::new( LagrangeKernelRandElements::new(eval_point), batching_randomness, - openings, + batched_openings, self.get_oracles().to_vec(), ) } @@ -156,7 +164,7 @@ pub trait LogUpGkrEvaluator: Clone + Sync { /// Returns the periodic values used in the LogUp-GKR statement, either as base field element /// during circuit evaluation or as extension field element during the run of sum-check for /// the input layer. - fn build_periodic_values(&self) -> PeriodicTable + fn build_periodic_values(&self, num_rows: usize) -> PeriodicTable where E: FieldElement, { @@ -166,7 +174,7 @@ pub trait LogUpGkrEvaluator: Clone + Sync { .map(|values| values.iter().map(|x| E::from(*x)).collect()) .collect(); - PeriodicTable { table } + PeriodicTable { table, num_rows } } } @@ -264,16 +272,17 @@ pub enum LogUpGkrOracle { #[derive(Clone, Debug, Default, PartialEq, PartialOrd, Eq, Ord)] pub struct PeriodicTable { pub table: Vec>, + pub num_rows: usize, } impl PeriodicTable where E: FieldElement, { - pub fn new(table: Vec>) -> Self { + pub fn new(table: Vec>, num_rows: usize) -> Self { let table = table.iter().map(|col| col.iter().map(|x| E::from(*x)).collect()).collect(); - Self { table } + Self { table, num_rows } } pub fn num_columns(&self) -> usize { @@ -293,13 +302,14 @@ where pub fn bind_least_significant_variable(&mut self, round_challenge: E) { for col in self.table.iter_mut() { - if col.len() > 1 { + if col.len() > 1 && self.num_rows <= col.len() { let num_evals = col.len() >> 1; for i in 0..num_evals { - col[i] = col[i << 1] + round_challenge * (col[(i << 1) + 1] - col[i << 1]); + col[i] = col[i] + round_challenge * (col[i + num_evals] - col[i]); } col.truncate(num_evals) } } + self.num_rows /= 2; } } diff --git a/examples/src/lib.rs b/examples/src/lib.rs index 33f733d7c..bf95e9f8a 100644 --- a/examples/src/lib.rs +++ b/examples/src/lib.rs @@ -9,10 +9,10 @@ use winterfell::{ math::fields::f128::BaseElement, FieldExtension, Proof, ProofOptions, VerifierError, }; - pub mod fibonacci; #[cfg(feature = "std")] pub mod lamport; +pub mod logup_gkr; #[cfg(feature = "std")] pub mod merkle; pub mod rescue; @@ -198,6 +198,12 @@ pub enum ExampleType { #[structopt(short = "n", default_value = "3")] num_signers: usize, }, + /// LogUp-GKR + LogUpGkr { + /// Length of the trace; must be a power of two + #[structopt(short = "n", default_value = "65536")] + trace_length: usize, + }, } /// Defines a set of hash functions available for the provided examples. Some examples may not diff --git a/examples/src/logup_gkr/air.rs b/examples/src/logup_gkr/air.rs new file mode 100644 index 000000000..3abc7fa56 --- /dev/null +++ b/examples/src/logup_gkr/air.rs @@ -0,0 +1,172 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use std::marker::PhantomData; + +use winterfell::{ + math::{fields::f64::BaseElement, ExtensionOf, FieldElement, StarkField}, + Air, AirContext, Assertion, AuxRandElements, EvaluationFrame, LogUpGkrEvaluator, + LogUpGkrOracle, TraceInfo, TransitionConstraintDegree, +}; + +use super::ProofOptions; + +pub const NUM_FRACTIONS: usize = 64; + +pub(crate) struct LogUpGkrSimpleAir { + context: AirContext, +} + +impl Air for LogUpGkrSimpleAir { + type BaseField = BaseElement; + type PublicInputs = (); + + fn new(trace_info: TraceInfo, _pub_inputs: Self::PublicInputs, options: ProofOptions) -> Self { + Self { + context: AirContext::new_multi_segment( + trace_info, + _pub_inputs, + vec![TransitionConstraintDegree::new(1)], + vec![], + 1, + 0, + options, + ), + } + } + + fn context(&self) -> &AirContext { + &self.context + } + + fn evaluate_transition>( + &self, + frame: &EvaluationFrame, + _periodic_values: &[E], + result: &mut [E], + ) { + let current = frame.current()[0]; + let next = frame.next()[0]; + + // increments by 1 + result[0] = next - current - E::ONE; + } + + fn get_assertions(&self) -> Vec> { + vec![Assertion::single(0, 0, BaseElement::ZERO)] + } + + fn evaluate_aux_transition( + &self, + _main_frame: &EvaluationFrame, + _aux_frame: &EvaluationFrame, + _periodic_values: &[F], + _aux_rand_elements: &AuxRandElements, + _result: &mut [E], + ) where + F: FieldElement, + E: FieldElement + ExtensionOf, + { + // do nothing + } + + fn get_aux_assertions>( + &self, + _aux_rand_elements: &AuxRandElements, + ) -> Vec> { + vec![] + } + + fn get_logup_gkr_evaluator( + &self, + ) -> impl LogUpGkrEvaluator + { + PlainLogUpGkrEval::new() + } +} + +#[derive(Clone, Default)] +pub struct PlainLogUpGkrEval { + oracles: Vec, + _field: PhantomData, +} + +impl PlainLogUpGkrEval { + pub fn new() -> Self { + let committed_0 = LogUpGkrOracle::CurrentRow(0); + let committed_1 = LogUpGkrOracle::CurrentRow(1); + let committed_2 = LogUpGkrOracle::CurrentRow(2); + let committed_3 = LogUpGkrOracle::CurrentRow(3); + let committed_4 = LogUpGkrOracle::CurrentRow(4); + let oracles = vec![committed_0, committed_1, committed_2, committed_3, committed_4]; + Self { oracles, _field: PhantomData } + } +} + +impl LogUpGkrEvaluator for PlainLogUpGkrEval { + type BaseField = BaseElement; + + type PublicInputs = (); + + fn get_oracles(&self) -> &[LogUpGkrOracle] { + &self.oracles + } + + fn get_num_rand_values(&self) -> usize { + 1 + } + + fn get_num_fractions(&self) -> usize { + NUM_FRACTIONS + } + + fn max_degree(&self) -> usize { + 10 + } + + fn build_query(&self, frame: &EvaluationFrame, query: &mut [E]) + where + E: FieldElement, + { + query.iter_mut().zip(frame.current().iter()).for_each(|(q, f)| *q = *f) + } + + fn evaluate_query( + &self, + query: &[F], + _periodic_values: &[F], + rand_values: &[E], + numerator: &mut [E], + denominator: &mut [E], + ) where + F: FieldElement, + E: FieldElement + ExtensionOf, + { + assert_eq!(numerator.len(), self.get_num_fractions()); + assert_eq!(denominator.len(), self.get_num_fractions()); + assert_eq!(query.len(), 5); + + for i in (0..self.get_num_fractions()).step_by(4) { + numerator[i] = E::from(query[1]); + numerator[i + 1] = E::ONE; + numerator[i + 2] = E::ONE; + numerator[i + 3] = E::ONE; + } + + for i in (0..self.get_num_fractions()).step_by(4) { + denominator[i] = rand_values[0] - E::from(query[0]); + denominator[i + 1] = -(rand_values[0] - E::from(query[2])); + denominator[i + 2] = -(rand_values[0] - E::from(query[3])); + denominator[i + 3] = -(rand_values[0] - E::from(query[4])); + } + } + + fn compute_claim(&self, _inputs: &Self::PublicInputs, _rand_values: &[E]) -> E + where + E: FieldElement, + { + E::ZERO + } +} diff --git a/examples/src/logup_gkr/mod.rs b/examples/src/logup_gkr/mod.rs new file mode 100644 index 000000000..d56b67a98 --- /dev/null +++ b/examples/src/logup_gkr/mod.rs @@ -0,0 +1,132 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use core::marker::PhantomData; + +use winterfell::{ + crypto::{DefaultRandomCoin, ElementHasher, MerkleTree}, + math::fields::f64::BaseElement, + Proof, ProofOptions, Prover, VerifierError, +}; + +use crate::{Example, ExampleOptions, HashFunction}; + +mod air; +use air::LogUpGkrSimpleAir; + +mod prover; +use prover::LogUpGkrSimpleProver; + +#[cfg(test)] +mod tests; + +// CONSTANTS AND TYPES +// ================================================================================================ + +const AUX_TRACE_WIDTH: usize = 2; + +type Blake3_192 = winterfell::crypto::hashers::Blake3_192; +type Blake3_256 = winterfell::crypto::hashers::Blake3_256; +type Sha3_256 = winterfell::crypto::hashers::Sha3_256; +type Rp64_256 = winterfell::crypto::hashers::Rp64_256; +type RpJive64_256 = winterfell::crypto::hashers::RpJive64_256; + +// FIBONACCI EXAMPLE +// ================================================================================================ + +pub fn get_example( + options: &ExampleOptions, + trace_length: usize, +) -> Result, String> { + let (options, hash_fn) = options.to_proof_options(28, 8); + + match hash_fn { + HashFunction::Blake3_192 => Ok(Box::new(LogUpGkrSimple::::new( + trace_length, + AUX_TRACE_WIDTH, + options, + ))), + HashFunction::Blake3_256 => Ok(Box::new(LogUpGkrSimple::::new( + trace_length, + AUX_TRACE_WIDTH, + options, + ))), + HashFunction::Sha3_256 => Ok(Box::new(LogUpGkrSimple::::new( + trace_length, + AUX_TRACE_WIDTH, + options, + ))), + HashFunction::Rp64_256 => Ok(Box::new(LogUpGkrSimple::::new( + trace_length, + AUX_TRACE_WIDTH, + options, + ))), + HashFunction::RpJive64_256 => Ok(Box::new(LogUpGkrSimple::::new( + trace_length, + AUX_TRACE_WIDTH, + options, + ))), + } +} + +#[derive(Clone, Debug)] +struct LogUpGkrSimple> { + trace_len: usize, + aux_segment_width: usize, + options: ProofOptions, + _hasher: PhantomData, +} + +impl> LogUpGkrSimple { + fn new(trace_len: usize, aux_segment_width: usize, options: ProofOptions) -> Self { + assert!(trace_len < u32::MAX.try_into().unwrap()); + + Self { + trace_len, + aux_segment_width, + options, + _hasher: PhantomData, + } + } +} + +// EXAMPLE IMPLEMENTATION +// ================================================================================================ + +impl Example for LogUpGkrSimple +where + H: ElementHasher + Sync + Send, +{ + fn prove(&self) -> Proof { + // create a prover + let prover = LogUpGkrSimpleProver::::new(AUX_TRACE_WIDTH, self.options.clone()); + + let trace = prover.build_trace(self.trace_len, self.aux_segment_width); + + // generate the proof + prover.prove(trace).unwrap() + } + + fn verify(&self, proof: Proof) -> Result<(), VerifierError> { + let acceptable_options = + winterfell::AcceptableOptions::OptionSet(vec![proof.options().clone()]); + + winterfell::verify::, MerkleTree>( + proof, + (), + &acceptable_options, + ) + } + + fn verify_with_wrong_inputs(&self, proof: Proof) -> Result<(), VerifierError> { + let acceptable_options = + winterfell::AcceptableOptions::OptionSet(vec![proof.options().clone()]); + winterfell::verify::, MerkleTree>( + proof, + (), + &acceptable_options, + ) + } +} diff --git a/examples/src/logup_gkr/prover.rs b/examples/src/logup_gkr/prover.rs new file mode 100644 index 000000000..5a4affe31 --- /dev/null +++ b/examples/src/logup_gkr/prover.rs @@ -0,0 +1,169 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use winterfell::{ + crypto::MerkleTree, math::FieldElement, matrix::ColMatrix, Air, AuxRandElements, + ConstraintCompositionCoefficients, DefaultTraceLde, EvaluationFrame, + LogUpGkrConstraintEvaluator, StarkDomain, Trace, TraceInfo, TracePolyTable, +}; + +use super::{ + air::LogUpGkrSimpleAir, BaseElement, DefaultRandomCoin, ElementHasher, PhantomData, + ProofOptions, Prover, +}; + +pub(crate) struct LogUpGkrSimpleProver + Sync + Send> { + aux_trace_width: usize, + options: ProofOptions, + _hasher: PhantomData, +} + +impl + Sync + Send> LogUpGkrSimpleProver { + pub(crate) fn new(aux_trace_width: usize, options: ProofOptions) -> Self { + Self { + aux_trace_width, + options, + _hasher: PhantomData, + } + } + + /// Builds an execution trace for computing a Fibonacci sequence of the specified length such + /// that each row advances the sequence by 2 terms. + pub fn build_trace(&self, trace_len: usize, aux_segment_width: usize) -> LogUpGkrSimpleTrace { + LogUpGkrSimpleTrace::new(trace_len, aux_segment_width) + } +} + +impl + Sync + Send> Prover for LogUpGkrSimpleProver { + type BaseField = BaseElement; + type Air = LogUpGkrSimpleAir; + type Trace = LogUpGkrSimpleTrace; + type HashFn = H; + type VC = MerkleTree; + type RandomCoin = DefaultRandomCoin; + type TraceLde> = + DefaultTraceLde; + type ConstraintEvaluator<'a, E: FieldElement> = + LogUpGkrConstraintEvaluator<'a, LogUpGkrSimpleAir, E>; + + fn get_pub_inputs(&self, _trace: &Self::Trace) -> <::Air as Air>::PublicInputs { + } + + fn options(&self) -> &ProofOptions { + &self.options + } + + fn new_trace_lde( + &self, + trace_info: &TraceInfo, + main_trace: &ColMatrix, + domain: &StarkDomain, + ) -> (Self::TraceLde, TracePolyTable) + where + E: FieldElement, + { + DefaultTraceLde::new(trace_info, main_trace, domain) + } + + fn new_evaluator<'a, E>( + &self, + air: &'a Self::Air, + aux_rand_elements: Option>, + composition_coefficients: ConstraintCompositionCoefficients, + ) -> Self::ConstraintEvaluator<'a, E> + where + E: FieldElement, + { + LogUpGkrConstraintEvaluator::new(air, aux_rand_elements.unwrap(), composition_coefficients) + } + + fn build_aux_trace(&self, main_trace: &Self::Trace, _aux_rand_elements: &[E]) -> ColMatrix + where + E: FieldElement, + { + let main_trace = main_trace.main_segment(); + + let mut columns = Vec::new(); + + let rand_summed = E::from(777_u32); + for _ in 0..self.aux_trace_width { + // building a dummy auxiliary column + let column = main_trace + .get_column(0) + .iter() + .map(|row_val| rand_summed.mul_base(*row_val)) + .collect(); + + columns.push(column); + } + + ColMatrix::new(columns) + } +} + +#[derive(Clone, Debug)] +pub(crate) struct LogUpGkrSimpleTrace { + // dummy main trace + main_trace: ColMatrix, + info: TraceInfo, +} + +impl LogUpGkrSimpleTrace { + fn new(trace_len: usize, aux_segment_width: usize) -> Self { + assert!(trace_len < u32::MAX.try_into().unwrap()); + + let table: Vec = + (0..trace_len).map(|idx| BaseElement::from(idx as u32)).collect(); + let mut multiplicity: Vec = + (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + multiplicity[0] = BaseElement::new(3 * trace_len as u64 - 3 * 4); + multiplicity[1] = BaseElement::new(3 * 4); + + let mut values_0: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + + for i in 0..4 { + values_0[i + 4] = BaseElement::ONE; + } + + let mut values_1: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + + for i in 0..4 { + values_1[i + 4] = BaseElement::ONE; + } + + let mut values_2: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + + for i in 0..4 { + values_2[i + 4] = BaseElement::ONE; + } + + Self { + main_trace: ColMatrix::new(vec![table, multiplicity, values_0, values_1, values_2]), + info: TraceInfo::new_multi_segment(5, aux_segment_width, 1, trace_len, vec![], true), + } + } + + fn len(&self) -> usize { + self.main_trace.num_rows() + } +} + +impl Trace for LogUpGkrSimpleTrace { + type BaseField = BaseElement; + + fn info(&self) -> &TraceInfo { + &self.info + } + + fn main_segment(&self) -> &ColMatrix { + &self.main_trace + } + + fn read_main_frame(&self, row_idx: usize, frame: &mut EvaluationFrame) { + let next_row_idx = row_idx + 1; + self.main_trace.read_row_into(row_idx, frame.current_mut()); + self.main_trace.read_row_into(next_row_idx % self.len(), frame.next_mut()); + } +} diff --git a/examples/src/logup_gkr/tests.rs b/examples/src/logup_gkr/tests.rs new file mode 100644 index 000000000..08a76acda --- /dev/null +++ b/examples/src/logup_gkr/tests.rs @@ -0,0 +1,48 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use winterfell::{FieldExtension, ProofOptions}; + +use super::{Rp64_256, AUX_TRACE_WIDTH}; + +#[test] +fn logup_gkr_small_test_basic_proof_verification() { + let logup_gkr = Box::new(super::LogUpGkrSimple::::new( + 128, + AUX_TRACE_WIDTH, + build_options(false), + )); + crate::tests::test_basic_proof_verification(logup_gkr); +} + +#[test] +fn logup_gkr_small_test_basic_proof_verification_extension() { + let logup_gkr = Box::new(super::LogUpGkrSimple::::new( + 128, + AUX_TRACE_WIDTH, + build_options(true), + )); + crate::tests::test_basic_proof_verification(logup_gkr); +} + +#[ignore = "not relevant"] +#[test] +fn logup_gkr_small_test_basic_proof_verification_fail() { + let logup_gkr = Box::new(super::LogUpGkrSimple::::new( + 128, + AUX_TRACE_WIDTH, + build_options(false), + )); + crate::tests::test_basic_proof_verification_fail(logup_gkr); +} + +fn build_options(use_extension_field: bool) -> ProofOptions { + let extension = if use_extension_field { + FieldExtension::Quadratic + } else { + FieldExtension::None + }; + ProofOptions::new(28, 8, 0, extension, 4, 31) +} diff --git a/examples/src/main.rs b/examples/src/main.rs index 36c1e0c0d..79a7a0e5f 100644 --- a/examples/src/main.rs +++ b/examples/src/main.rs @@ -5,7 +5,7 @@ use std::time::Instant; -use examples::{fibonacci, rescue, vdf, ExampleOptions, ExampleType}; +use examples::{fibonacci, logup_gkr, rescue, vdf, ExampleOptions, ExampleType}; #[cfg(feature = "std")] use examples::{lamport, merkle, rescue_raps}; use structopt::StructOpt; @@ -82,6 +82,7 @@ fn main() { ExampleType::LamportT { num_signers } => { lamport::threshold::get_example(&options, num_signers) }, + ExampleType::LogUpGkr { trace_length } => logup_gkr::get_example(&options, trace_length), } .expect("The example failed to initialize."); diff --git a/prover/Cargo.toml b/prover/Cargo.toml index 37e45c472..86aa1f334 100644 --- a/prover/Cargo.toml +++ b/prover/Cargo.toml @@ -19,6 +19,10 @@ bench = false name = "logup_gkr" harness = false +[[bench]] +name = "logup_gkr_e2e" +harness = false + [[bench]] name = "row_matrix" harness = false @@ -29,7 +33,7 @@ harness = false [features] async = ["maybe_async/async"] -concurrent = ["crypto/concurrent", "math/concurrent", "fri/concurrent", "utils/concurrent", "std"] +concurrent = ["crypto/concurrent", "math/concurrent", "fri/concurrent", "utils/concurrent", "sumcheck/concurrent", "std"] default = ["std"] std = ["air/std", "crypto/std", "fri/std", "math/std", "utils/std"] @@ -43,6 +47,7 @@ sumcheck = { version = "0.1", path = "../sumcheck", package = "winter-sumcheck", thiserror = { version = "1.0", git = "https://github.com/bitwalker/thiserror", branch = "no-std", default-features = false } tracing = { version = "0.1", default-features = false, features = ["attributes"]} utils = { version = "0.9", path = "../utils/core", package = "winter-utils", default-features = false } +libc-print = "0.1.23" [dev-dependencies] criterion = "0.5" diff --git a/prover/benches/logup_gkr.rs b/prover/benches/logup_gkr.rs index 6e67eddc2..f5673c8aa 100644 --- a/prover/benches/logup_gkr.rs +++ b/prover/benches/logup_gkr.rs @@ -6,24 +6,22 @@ use std::{marker::PhantomData, time::Duration, vec::Vec}; use air::{ - Air, AirContext, Assertion, AuxRandElements, ConstraintCompositionCoefficients, - EvaluationFrame, FieldExtension, LogUpGkrEvaluator, LogUpGkrOracle, ProofOptions, TraceInfo, - TransitionConstraintDegree, + Air, AirContext, Assertion, AuxRandElements, EvaluationFrame, LogUpGkrEvaluator, + LogUpGkrOracle, ProofOptions, TraceInfo, TransitionConstraintDegree, }; use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion}; -use crypto::MerkleTree; +use crypto::RandomCoin; use math::StarkField; use winter_prover::{ crypto::{hashers::Blake3_256, DefaultRandomCoin}, math::{fields::f64::BaseElement, ExtensionOf, FieldElement}, matrix::ColMatrix, - DefaultTraceLde, LogUpGkrConstraintEvaluator, Prover, StarkDomain, Trace, TracePolyTable, + prove_gkr, Trace, }; -const TRACE_LENS: [usize; 2] = [2_usize.pow(18), 2_usize.pow(20)]; -const AUX_TRACE_WIDTH: usize = 2; +const TRACE_LENS: [usize; 4] = [2_usize.pow(18), 2_usize.pow(19), 2_usize.pow(20), 2_usize.pow(21)]; -/// Simple end-to-end benchmark for LogUp-GKR. +/// Simple benchmark for the GKR part of STARK with LogUp-GKR. /// /// The main trace contains `5` columns and the LogUp relation is a simple one where we have: /// @@ -33,27 +31,31 @@ const AUX_TRACE_WIDTH: usize = 2; /// /// Given the above, the benchmark then gives an idea about the minimal overhead due to enabling /// LogUp-GKR. The overhead could be bigger depending on the complexity of the LogUp relation. -fn prove_with_lagrange_kernel(c: &mut Criterion) { - let mut group = c.benchmark_group("prove with Lagrange kernel column"); +fn prove_with_logup_gkr(c: &mut Criterion) { + let mut group = c.benchmark_group("prove LogUp-GKR"); group.sample_size(10); group.measurement_time(Duration::from_secs(20)); for &trace_len in TRACE_LENS.iter() { group.bench_function(BenchmarkId::new("", trace_len), |b| { - let trace = LogUpGkrSimpleTrace::new(trace_len, AUX_TRACE_WIDTH); - let prover = LogUpGkrSimpleProver::new(AUX_TRACE_WIDTH); + let main_trace = LogUpGkrSimpleTrace::new(trace_len); + let evaluator = PlainLogUpGkrEval::new(); b.iter_batched( - || trace.clone(), - |trace| prover.prove(trace).unwrap(), + || (main_trace.clone(), evaluator.clone()), + |(main_trace, evaluator)| { + let mut public_coin = + DefaultRandomCoin::>::new(&[BaseElement::ZERO; 4]); + prove_gkr::(&main_trace, &evaluator, &mut public_coin) + }, BatchSize::SmallInput, ) }); } } -criterion_group!(lagrange_kernel_group, prove_with_lagrange_kernel); -criterion_main!(lagrange_kernel_group); +criterion_group!(logup_gkr_group, prove_with_logup_gkr); +criterion_main!(logup_gkr_group); // LogUpGkrSimple // ================================================================================================= @@ -66,7 +68,7 @@ struct LogUpGkrSimpleTrace { } impl LogUpGkrSimpleTrace { - fn new(trace_len: usize, aux_segment_width: usize) -> Self { + fn new(trace_len: usize) -> Self { assert!(trace_len < u32::MAX.try_into().unwrap()); // we create a column for the table we are looking values into. These are just the integers @@ -103,7 +105,7 @@ impl LogUpGkrSimpleTrace { Self { main_trace: ColMatrix::new(vec![table, multiplicity, values_0, values_1, values_2]), - info: TraceInfo::new_multi_segment(5, aux_segment_width, 0, trace_len, vec![], true), + info: TraceInfo::new_multi_segment(5, 0, 0, trace_len, vec![], true), } } @@ -237,11 +239,11 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { } fn get_num_fractions(&self) -> usize { - 4 + 16 } fn max_degree(&self) -> usize { - 3 + 10 } fn build_query(&self, frame: &EvaluationFrame, query: &mut [E]) @@ -262,18 +264,42 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { F: FieldElement, E: FieldElement + ExtensionOf, { - assert_eq!(numerator.len(), 4); - assert_eq!(denominator.len(), 4); + assert_eq!(numerator.len(), 16); + assert_eq!(denominator.len(), 16); assert_eq!(query.len(), 5); numerator[0] = E::from(query[1]); numerator[1] = E::ONE; numerator[2] = E::ONE; numerator[3] = E::ONE; + numerator[4] = E::from(query[1]); + numerator[5] = E::ONE; + numerator[6] = E::ONE; + numerator[7] = E::ONE; + numerator[8] = E::from(query[1]); + numerator[9] = E::ONE; + numerator[10] = E::ONE; + numerator[11] = E::ONE; + numerator[12] = E::from(query[1]); + numerator[13] = E::ONE; + numerator[14] = E::ONE; + numerator[15] = E::ONE; denominator[0] = rand_values[0] - E::from(query[0]); denominator[1] = -(rand_values[0] - E::from(query[2])); denominator[2] = -(rand_values[0] - E::from(query[3])); denominator[3] = -(rand_values[0] - E::from(query[4])); + denominator[4] = rand_values[0] - E::from(query[0]); + denominator[5] = -(rand_values[0] - E::from(query[2])); + denominator[6] = -(rand_values[0] - E::from(query[3])); + denominator[7] = -(rand_values[0] - E::from(query[4])); + denominator[8] = rand_values[0] - E::from(query[0]); + denominator[9] = -(rand_values[0] - E::from(query[2])); + denominator[10] = -(rand_values[0] - E::from(query[3])); + denominator[11] = -(rand_values[0] - E::from(query[4])); + denominator[12] = rand_values[0] - E::from(query[0]); + denominator[13] = -(rand_values[0] - E::from(query[2])); + denominator[14] = -(rand_values[0] - E::from(query[3])); + denominator[15] = -(rand_values[0] - E::from(query[4])); } fn compute_claim(&self, _inputs: &Self::PublicInputs, _rand_values: &[E]) -> E @@ -283,86 +309,3 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { E::ZERO } } -// Prover -// ================================================================================================ - -struct LogUpGkrSimpleProver { - aux_trace_width: usize, - options: ProofOptions, -} - -impl LogUpGkrSimpleProver { - fn new(aux_trace_width: usize) -> Self { - Self { - aux_trace_width, - options: ProofOptions::new(1, 8, 0, FieldExtension::Quadratic, 2, 1), - } - } -} - -impl Prover for LogUpGkrSimpleProver { - type BaseField = BaseElement; - type Air = LogUpGkrSimpleAir; - type Trace = LogUpGkrSimpleTrace; - type HashFn = Blake3_256; - type VC = MerkleTree>; - type RandomCoin = DefaultRandomCoin; - type TraceLde> = - DefaultTraceLde; - type ConstraintEvaluator<'a, E: FieldElement> = - LogUpGkrConstraintEvaluator<'a, LogUpGkrSimpleAir, E>; - - fn get_pub_inputs(&self, _trace: &Self::Trace) -> <::Air as Air>::PublicInputs { - } - - fn options(&self) -> &ProofOptions { - &self.options - } - - fn new_trace_lde( - &self, - trace_info: &TraceInfo, - main_trace: &ColMatrix, - domain: &StarkDomain, - ) -> (Self::TraceLde, TracePolyTable) - where - E: math::FieldElement, - { - DefaultTraceLde::new(trace_info, main_trace, domain) - } - - fn new_evaluator<'a, E>( - &self, - air: &'a Self::Air, - aux_rand_elements: Option>, - composition_coefficients: ConstraintCompositionCoefficients, - ) -> Self::ConstraintEvaluator<'a, E> - where - E: math::FieldElement, - { - LogUpGkrConstraintEvaluator::new(air, aux_rand_elements.unwrap(), composition_coefficients) - } - - fn build_aux_trace(&self, main_trace: &Self::Trace, _aux_rand_elements: &[E]) -> ColMatrix - where - E: FieldElement, - { - let main_trace = main_trace.main_segment(); - - let mut columns = Vec::new(); - - let rand_summed = E::from(777_u32); - for _ in 0..self.aux_trace_width { - // building a dummy auxiliary column - let column = main_trace - .get_column(0) - .iter() - .map(|row_val| rand_summed.mul_base(*row_val)) - .collect(); - - columns.push(column); - } - - ColMatrix::new(columns) - } -} diff --git a/prover/benches/logup_gkr_e2e.rs b/prover/benches/logup_gkr_e2e.rs new file mode 100644 index 000000000..c15630085 --- /dev/null +++ b/prover/benches/logup_gkr_e2e.rs @@ -0,0 +1,392 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use std::{marker::PhantomData, time::Duration, vec::Vec}; + +use air::{ + Air, AirContext, Assertion, AuxRandElements, ConstraintCompositionCoefficients, + EvaluationFrame, FieldExtension, LogUpGkrEvaluator, LogUpGkrOracle, ProofOptions, TraceInfo, + TransitionConstraintDegree, +}; +use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion}; +use crypto::MerkleTree; +use math::StarkField; +use winter_prover::{ + crypto::{hashers::Blake3_256, DefaultRandomCoin}, + math::{fields::f64::BaseElement, ExtensionOf, FieldElement}, + matrix::ColMatrix, + DefaultTraceLde, LogUpGkrConstraintEvaluator, Prover, StarkDomain, Trace, TracePolyTable, +}; + +const TRACE_LENS: [usize; 2] = [2_usize.pow(18), 2_usize.pow(20)]; +const AUX_TRACE_WIDTH: usize = 2; + +/// Simple end-to-end benchmark for LogUp-GKR. +/// +/// The main trace contains `5` columns and the LogUp relation is a simple one where we have: +/// +/// 1. a table of values from `0` to `trace_len - 1`. +/// 2. a multiplicity column containing the number of look ups for each value in the table. +/// 3. three columns with values contained in the table above. +/// +/// Given the above, the benchmark then gives an idea about the minimal overhead due to enabling +/// LogUp-GKR. The overhead could be bigger depending on the complexity of the LogUp relation. +fn prove_with_logup_gkr(c: &mut Criterion) { + let mut group = c.benchmark_group("prove with LogUp-GKR"); + group.sample_size(10); + group.measurement_time(Duration::from_secs(20)); + + for &trace_len in TRACE_LENS.iter() { + group.bench_function(BenchmarkId::new("", trace_len), |b| { + let trace = LogUpGkrSimpleTrace::new(trace_len, AUX_TRACE_WIDTH); + let prover = LogUpGkrSimpleProver::new(AUX_TRACE_WIDTH); + + b.iter_batched( + || trace.clone(), + |trace| prover.prove(trace).unwrap(), + BatchSize::SmallInput, + ) + }); + } +} + +criterion_group!(lagrange_kernel_group, prove_with_logup_gkr); +criterion_main!(lagrange_kernel_group); + +// LogUpGkrSimple +// ================================================================================================= + +#[derive(Clone, Debug)] +struct LogUpGkrSimpleTrace { + // dummy main trace + main_trace: ColMatrix, + info: TraceInfo, +} + +impl LogUpGkrSimpleTrace { + fn new(trace_len: usize, aux_segment_width: usize) -> Self { + assert!(trace_len < u32::MAX.try_into().unwrap()); + + // we create a column for the table we are looking values into. These are just the integers + // from 0 to `trace_len`. + let table: Vec = + (0..trace_len).map(|idx| BaseElement::from(idx as u32)).collect(); + + // we create three columns that contains values contained in `table`. For simplicity, we + // look up only the values `0` or `1`, we look up the value `1` four times and the value `0` + // `trace_len - 4` times. + + let mut values_0: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + for i in 0..4 { + values_0[i + 4] = BaseElement::ONE; + } + + let mut values_1: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + for i in 0..4 { + values_1[i + 4] = BaseElement::ONE; + } + + let mut values_2: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + for i in 0..4 { + values_2[i + 4] = BaseElement::ONE; + } + + // we create the multiplicity column + let mut multiplicity: Vec = + (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + // we look up the value `1` four times in three columns + multiplicity[1] = BaseElement::new(3 * 4); + // we look up the value `0` `trace_len - 4` in three columns + multiplicity[0] = BaseElement::new(3 * trace_len as u64 - 3 * 4); + + Self { + main_trace: ColMatrix::new(vec![table, multiplicity, values_0, values_1, values_2]), + info: TraceInfo::new_multi_segment(5, aux_segment_width, 0, trace_len, vec![], true), + } + } + + fn len(&self) -> usize { + self.main_trace.num_rows() + } +} + +impl Trace for LogUpGkrSimpleTrace { + type BaseField = BaseElement; + + fn info(&self) -> &TraceInfo { + &self.info + } + + fn main_segment(&self) -> &ColMatrix { + &self.main_trace + } + + fn read_main_frame(&self, row_idx: usize, frame: &mut EvaluationFrame) { + let next_row_idx = row_idx + 1; + self.main_trace.read_row_into(row_idx, frame.current_mut()); + self.main_trace.read_row_into(next_row_idx % self.len(), frame.next_mut()); + } +} + +// AIR +// ================================================================================================= + +struct LogUpGkrSimpleAir { + context: AirContext, +} + +impl Air for LogUpGkrSimpleAir { + type BaseField = BaseElement; + type PublicInputs = (); + + fn new(trace_info: TraceInfo, _pub_inputs: Self::PublicInputs, options: ProofOptions) -> Self { + Self { + context: AirContext::new_multi_segment( + trace_info, + _pub_inputs, + vec![TransitionConstraintDegree::new(1)], + vec![], + 1, + 0, + options, + ), + } + } + + fn context(&self) -> &AirContext { + &self.context + } + + fn evaluate_transition>( + &self, + frame: &EvaluationFrame, + _periodic_values: &[E], + result: &mut [E], + ) { + let current = frame.current()[0]; + let next = frame.next()[0]; + + // increments by 1 + result[0] = next - current - E::ONE; + } + + fn get_assertions(&self) -> Vec> { + vec![Assertion::single(0, 0, BaseElement::ZERO)] + } + + fn evaluate_aux_transition( + &self, + _main_frame: &EvaluationFrame, + _aux_frame: &EvaluationFrame, + _periodic_values: &[F], + _aux_rand_elements: &AuxRandElements, + _result: &mut [E], + ) where + F: FieldElement, + E: FieldElement + ExtensionOf, + { + // do nothing + } + + fn get_aux_assertions>( + &self, + _aux_rand_elements: &AuxRandElements, + ) -> Vec> { + vec![] + } + + fn get_logup_gkr_evaluator( + &self, + ) -> impl LogUpGkrEvaluator + { + PlainLogUpGkrEval::new() + } +} + +#[derive(Clone, Default)] +pub struct PlainLogUpGkrEval { + oracles: Vec, + _field: PhantomData, +} + +impl PlainLogUpGkrEval { + pub fn new() -> Self { + let committed_0 = LogUpGkrOracle::CurrentRow(0); + let committed_1 = LogUpGkrOracle::CurrentRow(1); + let committed_2 = LogUpGkrOracle::CurrentRow(2); + let committed_3 = LogUpGkrOracle::CurrentRow(3); + let committed_4 = LogUpGkrOracle::CurrentRow(4); + let oracles = vec![committed_0, committed_1, committed_2, committed_3, committed_4]; + Self { oracles, _field: PhantomData } + } +} + +impl LogUpGkrEvaluator for PlainLogUpGkrEval { + type BaseField = BaseElement; + + type PublicInputs = (); + + fn get_oracles(&self) -> &[LogUpGkrOracle] { + &self.oracles + } + + fn get_num_rand_values(&self) -> usize { + 1 + } + + fn get_num_fractions(&self) -> usize { + 16 + } + + fn max_degree(&self) -> usize { + 10 + } + + fn build_query(&self, frame: &EvaluationFrame, query: &mut [E]) + where + E: FieldElement, + { + query.iter_mut().zip(frame.current().iter()).for_each(|(q, f)| *q = *f) + } + + fn evaluate_query( + &self, + query: &[F], + _periodic_values: &[F], + rand_values: &[E], + numerator: &mut [E], + denominator: &mut [E], + ) where + F: FieldElement, + E: FieldElement + ExtensionOf, + { + assert_eq!(numerator.len(), 16); + assert_eq!(denominator.len(), 16); + assert_eq!(query.len(), 5); + numerator[0] = E::from(query[1]); + numerator[1] = E::ONE; + numerator[2] = E::ONE; + numerator[3] = E::ONE; + numerator[4] = E::from(query[1]); + numerator[5] = E::ONE; + numerator[6] = E::ONE; + numerator[7] = E::ONE; + numerator[8] = E::from(query[1]); + numerator[9] = E::ONE; + numerator[10] = E::ONE; + numerator[11] = E::ONE; + numerator[12] = E::from(query[1]); + numerator[13] = E::ONE; + numerator[14] = E::ONE; + numerator[15] = E::ONE; + + denominator[0] = rand_values[0] - E::from(query[0]); + denominator[1] = -(rand_values[0] - E::from(query[2])); + denominator[2] = -(rand_values[0] - E::from(query[3])); + denominator[3] = -(rand_values[0] - E::from(query[4])); + denominator[4] = rand_values[0] - E::from(query[0]); + denominator[5] = -(rand_values[0] - E::from(query[2])); + denominator[6] = -(rand_values[0] - E::from(query[3])); + denominator[7] = -(rand_values[0] - E::from(query[4])); + denominator[8] = rand_values[0] - E::from(query[0]); + denominator[9] = -(rand_values[0] - E::from(query[2])); + denominator[10] = -(rand_values[0] - E::from(query[3])); + denominator[11] = -(rand_values[0] - E::from(query[4])); + denominator[12] = rand_values[0] - E::from(query[0]); + denominator[13] = -(rand_values[0] - E::from(query[2])); + denominator[14] = -(rand_values[0] - E::from(query[3])); + denominator[15] = -(rand_values[0] - E::from(query[4])); + } + + fn compute_claim(&self, _inputs: &Self::PublicInputs, _rand_values: &[E]) -> E + where + E: FieldElement, + { + E::ZERO + } +} +// Prover +// ================================================================================================ + +struct LogUpGkrSimpleProver { + aux_trace_width: usize, + options: ProofOptions, +} + +impl LogUpGkrSimpleProver { + fn new(aux_trace_width: usize) -> Self { + Self { + aux_trace_width, + options: ProofOptions::new(1, 8, 0, FieldExtension::Quadratic, 2, 1), + } + } +} + +impl Prover for LogUpGkrSimpleProver { + type BaseField = BaseElement; + type Air = LogUpGkrSimpleAir; + type Trace = LogUpGkrSimpleTrace; + type HashFn = Blake3_256; + type VC = MerkleTree>; + type RandomCoin = DefaultRandomCoin; + type TraceLde> = + DefaultTraceLde; + type ConstraintEvaluator<'a, E: FieldElement> = + LogUpGkrConstraintEvaluator<'a, LogUpGkrSimpleAir, E>; + + fn get_pub_inputs(&self, _trace: &Self::Trace) -> <::Air as Air>::PublicInputs { + } + + fn options(&self) -> &ProofOptions { + &self.options + } + + fn new_trace_lde( + &self, + trace_info: &TraceInfo, + main_trace: &ColMatrix, + domain: &StarkDomain, + ) -> (Self::TraceLde, TracePolyTable) + where + E: math::FieldElement, + { + DefaultTraceLde::new(trace_info, main_trace, domain) + } + + fn new_evaluator<'a, E>( + &self, + air: &'a Self::Air, + aux_rand_elements: Option>, + composition_coefficients: ConstraintCompositionCoefficients, + ) -> Self::ConstraintEvaluator<'a, E> + where + E: math::FieldElement, + { + LogUpGkrConstraintEvaluator::new(air, aux_rand_elements.unwrap(), composition_coefficients) + } + + fn build_aux_trace(&self, main_trace: &Self::Trace, _aux_rand_elements: &[E]) -> ColMatrix + where + E: FieldElement, + { + let main_trace = main_trace.main_segment(); + + let mut columns = Vec::new(); + + let rand_summed = E::from(777_u32); + for _ in 0..self.aux_trace_width { + // building a dummy auxiliary column + let column = main_trace + .get_column(0) + .iter() + .map(|row_val| rand_summed.mul_base(*row_val)) + .collect(); + + columns.push(column); + } + + ColMatrix::new(columns) + } +} diff --git a/prover/src/lib.rs b/prover/src/lib.rs index 9bcc566b6..d9da99dd5 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -88,7 +88,7 @@ pub use trace::{ }; mod logup_gkr; -use logup_gkr::{build_lagrange_column, build_s_column, prove_gkr}; +pub use logup_gkr::{build_lagrange_column, build_s_column, prove_gkr}; mod channel; use channel::ProverChannel; diff --git a/prover/src/logup_gkr/mod.rs b/prover/src/logup_gkr/mod.rs index 016cd5218..5081bf016 100644 --- a/prover/src/logup_gkr/mod.rs +++ b/prover/src/logup_gkr/mod.rs @@ -1,21 +1,20 @@ use alloc::vec::Vec; -use core::ops::Add; use air::{EvaluationFrame, GkrData, LogUpGkrEvaluator}; use math::FieldElement; -use sumcheck::{EqFunction, MultiLinearPoly, SumCheckProverError}; +use sumcheck::{CircuitLayerPolys, CircuitWire, EqFunction, SumCheckProverError}; use tracing::instrument; -use utils::{ - batch_iter_mut, chunks, uninit_vector, ByteReader, ByteWriter, Deserializable, - DeserializationError, Serializable, -}; +use utils::{batch_iter_mut, uninit_vector}; use crate::Trace; mod prover; pub use prover::prove_gkr; #[cfg(feature = "concurrent")] -pub use utils::rayon::{current_num_threads as rayon_num_threads, prelude::*}; +pub use utils::{ + rayon::{current_num_threads as rayon_num_threads, prelude::*}, + {chunks, chunks_mut, iter, iter_mut}, +}; // EVALUATED CIRCUIT // ================================================================================================ @@ -56,7 +55,7 @@ pub use utils::rayon::{current_num_threads as rayon_num_threads, prelude::*}; /// This means that layer ν will be the output layer and will consist of four values /// (p_0[ν - 1], p_1[ν - 1], p_0[ν - 1], p_1[ν - 1]) ∈ 𝔽^ν. pub struct EvaluatedCircuit { - layer_polys: Vec>, + layer_polys: Vec>>, } impl EvaluatedCircuit { @@ -70,12 +69,15 @@ impl EvaluatedCircuit { ) -> Result { let mut layer_polys = Vec::new(); - let mut current_layer = + let input_layer = Self::generate_input_layer(main_trace_columns, evaluator, log_up_randomness); - while current_layer.num_wires() > 1 { + + let mut current_layer = + Self::generate_second_layer(input_layer, evaluator.get_num_fractions()); + while current_layer[0].len() > 1 { let next_layer = Self::compute_next_layer(¤t_layer); - layer_polys.push(CircuitLayerPolys::from_circuit_layer(current_layer)); + layer_polys.push(CircuitLayerPolys::from_circuit_layer(¤t_layer)); current_layer = next_layer; } @@ -88,21 +90,25 @@ impl EvaluatedCircuit { /// Note that the return type is a slice of [`CircuitLayerPolys`] as opposed to /// [`CircuitLayer`], since the evaluated layers are stored in a representation which can be /// proved using GKR. - pub fn layers(self) -> Vec> { + pub fn layers(self) -> Vec>> { self.layer_polys } /// Returns the numerator/denominator polynomials representing the output layer of the circuit. - pub fn output_layer(&self) -> &CircuitLayerPolys { + pub fn output_layers(&self) -> &Vec> { self.layer_polys.last().expect("circuit has at least one layer") } /// Evaluates the output layer at `query`, where the numerators of the output layer are treated /// as evaluations of a multilinear polynomial, and similarly for the denominators. - pub fn evaluate_output_layer(&self, query: E) -> (E, E) { - let CircuitLayerPolys { numerators, denominators } = self.output_layer(); + pub fn evaluate_output_layer(&self, query: E) -> Vec<(E, E)> { + let mut res = Vec::with_capacity(self.output_layers().len()); + for output_layer in self.output_layers().iter() { + let CircuitLayerPolys { numerators, denominators } = output_layer; - (numerators.evaluate(&[query]), denominators.evaluate(&[query])) + res.push((numerators.evaluate(&[query]), denominators.evaluate(&[query]))) + } + res } // HELPERS @@ -114,9 +120,9 @@ impl EvaluatedCircuit { trace: &impl Trace, evaluator: &impl LogUpGkrEvaluator, log_up_randomness: &[E], - ) -> CircuitLayer { + ) -> Vec> { let num_fractions = evaluator.get_num_fractions(); - let periodic_values = evaluator.build_periodic_values(); + let periodic_values = evaluator.build_periodic_values(trace.main_segment().num_rows()); let mut input_layer_wires = unsafe { uninit_vector(trace.main_segment().num_rows() * num_fractions) }; @@ -164,161 +170,85 @@ impl EvaluatedCircuit { } } ); - - CircuitLayer::new(input_layer_wires) + input_layer_wires } /// Computes the subsequent layer of the circuit from a given layer. - fn compute_next_layer(prev_layer: &CircuitLayer) -> CircuitLayer { - let next_layer_wires = chunks!(prev_layer.wires(), 2) - .map(|input_wires| { - let left_input_wire = input_wires[0]; - let right_input_wire = input_wires[1]; - - // output wire - left_input_wire + right_input_wire - }) - .collect(); - - CircuitLayer::new(next_layer_wires) - } -} - -// CIRCUIT LAYER POLYS -// =============================================================================================== - -/// Holds a layer of an [`EvaluatedCircuit`] in a representation amenable to proving circuit -/// evaluation using GKR. -#[derive(Clone, Debug)] -pub struct CircuitLayerPolys { - pub numerators: MultiLinearPoly, - pub denominators: MultiLinearPoly, -} - -impl CircuitLayerPolys -where - E: FieldElement, -{ - pub fn from_circuit_layer(layer: CircuitLayer) -> Self { - Self::from_wires(layer.wires) - } - - pub fn from_wires(wires: Vec>) -> Self { - let mut numerators = Vec::new(); - let mut denominators = Vec::new(); - - for wire in wires { - numerators.push(wire.numerator); - denominators.push(wire.denominator); + fn compute_next_layer(prev_layers: &[Vec>]) -> Vec>> { + let mut next_layers: Vec>> = + vec![unsafe { uninit_vector(prev_layers[0].len() / 2) }; prev_layers.len()]; + + #[cfg(feature = "concurrent")] + if prev_layers[0].len() >= 16 { + next_layers.par_iter_mut().enumerate().for_each(|(circuit_idx, circuit)| { + prev_layers[circuit_idx].chunks(2).enumerate().for_each( + |(row, fractions_at_row)| { + let left = fractions_at_row[0]; + let right = fractions_at_row[1]; + circuit[row] = left + right; + }, + ); + }); + } else { + next_layers.iter_mut().enumerate().for_each(|(circuit_idx, circuit)| { + prev_layers[circuit_idx].chunks(2).enumerate().for_each( + |(row, fractions_at_row)| { + let left = fractions_at_row[0]; + let right = fractions_at_row[1]; + circuit[row] = left + right; + }, + ); + }); } - Self { - numerators: MultiLinearPoly::from_evaluations(numerators), - denominators: MultiLinearPoly::from_evaluations(denominators), - } + #[cfg(not(feature = "concurrent"))] + next_layers.iter_mut().enumerate().for_each(|(circuit_idx, circuit)| { + prev_layers[circuit_idx] + .chunks(2) + .enumerate() + .for_each(|(row, fractions_at_row)| { + let left = fractions_at_row[0]; + let right = fractions_at_row[1]; + circuit[row] = left + right; + }); + }); + + next_layers } - fn into_numerators_denominators(self) -> (MultiLinearPoly, MultiLinearPoly) { - (self.numerators, self.denominators) - } -} - -impl Serializable for CircuitLayerPolys -where - E: FieldElement, -{ - fn write_into(&self, target: &mut W) { - let Self { numerators, denominators } = self; - numerators.write_into(target); - denominators.write_into(target); - } -} - -impl Deserializable for CircuitLayerPolys -where - E: FieldElement, -{ - fn read_from(source: &mut R) -> Result { - Ok(Self { - numerators: MultiLinearPoly::read_from(source)?, - denominators: MultiLinearPoly::read_from(source)?, - }) - } -} - -// CIRCUIT LAYER -// =============================================================================================== - -/// Represents a layer in a [`EvaluatedCircuit`]. -/// -/// A layer is made up of a set of `n` wires, where `n` is a power of two. This is the natural -/// circuit representation of a layer, where each consecutive pair of wires are summed to yield a -/// wire in the subsequent layer of an [`EvaluatedCircuit`]. -/// -/// Note that a [`Layer`] needs to be first converted to a [`LayerPolys`] before the evaluation of -/// the layer can be proved using GKR. -pub struct CircuitLayer { - wires: Vec>, -} - -impl CircuitLayer { - /// Creates a new [`Layer`] from a set of projective coordinates. - /// - /// Panics if the number of projective coordinates is not a power of two. - pub fn new(wires: Vec>) -> Self { - assert!(wires.len().is_power_of_two()); - - Self { wires } - } - - /// Returns the wires that make up this circuit layer. - pub fn wires(&self) -> &[CircuitWire] { - &self.wires - } - - /// Returns the number of wires in the layer. - pub fn num_wires(&self) -> usize { - self.wires.len() - } -} - -// CIRCUIT WIRE -// =============================================================================================== - -/// Represents a fraction `numerator / denominator` as a pair `(numerator, denominator)`. This is -/// the type for the gates' inputs in [`prover::EvaluatedCircuit`]. -/// -/// Hence, addition is defined in the natural way fractions are added together: `a/b + c/d = (ad + -/// bc) / bd`. -#[derive(Debug, Clone, Copy)] -pub struct CircuitWire { - numerator: E, - denominator: E, -} - -impl CircuitWire -where - E: FieldElement, -{ - /// Creates new projective coordinates from a numerator and a denominator. - pub fn new(numerator: E, denominator: E) -> Self { - assert_ne!(denominator, E::ZERO); - - Self { numerator, denominator } - } -} - -impl Add for CircuitWire -where - E: FieldElement, -{ - type Output = Self; - - fn add(self, other: Self) -> Self { - let numerator = self.numerator * other.denominator + other.numerator * self.denominator; - let denominator = self.denominator * other.denominator; - - Self::new(numerator, denominator) + fn generate_second_layer( + current_layer: Vec>, + num_fractions: usize, + ) -> Vec>> { + let mut result: Vec>> = + vec![ + unsafe { uninit_vector(current_layer.len() / (num_fractions * 2)) }; + num_fractions + ]; + + #[cfg(feature = "concurrent")] + result.par_iter_mut().enumerate().for_each(|(circuit_idx, circuit)| { + current_layer.chunks(2 * num_fractions).enumerate().for_each( + |(row, fractions_at_row)| { + let left = fractions_at_row[circuit_idx]; + let right = fractions_at_row[circuit_idx + num_fractions]; + circuit[row] = left + right; + }, + ); + }); + + #[cfg(not(feature = "concurrent"))] + result.iter_mut().enumerate().for_each(|(circuit_idx, circuit)| { + current_layer.chunks(2 * num_fractions).enumerate().for_each( + |(row, fractions_at_row)| { + let left = fractions_at_row[circuit_idx]; + let right = fractions_at_row[circuit_idx + num_fractions]; + circuit[row] = left + right; + }, + ); + }); + + result } } @@ -326,47 +256,7 @@ where #[derive(Debug)] pub struct GkrClaim { pub evaluation_point: Vec, - pub claimed_evaluation: (E, E), -} - -/// We receive our 4 multilinear polynomials which were evaluated at a random point: -/// `left_numerators` (or `p0`), `right_numerators` (or `p1`), `left_denominators` (or `q0`), and -/// `right_denominators` (or `q1`). We'll call the 4 evaluations at a random point `p0(r)`, `p1(r)`, -/// `q0(r)`, and `q1(r)`, respectively, where `r` is the random point. Note that `r` is a shorthand -/// for a tuple of random values `(r_0, ... r_{l-1})`, where `2^{l + 1}` is the number of wires in -/// the layer. -/// -/// It is important to recall how `p0` and `p1` were constructed (and analogously for `q0` and -/// `q1`). They are the `numerators` layer polynomial (or `p`) evaluations `p(0, r)` and `p(1, r)`, -/// obtained from [`MultiLinearPoly::project_least_significant_variable`]. Hence, `[p0, p1]` form -/// the evaluations of polynomial `p'(x_0) = p(x_0, r)`. Then, the round claim for `numerators`, -/// defined as `p(r_layer, r)`, is simply `p'(r_layer)`. -fn reduce_layer_claim( - left_numerators_opening: E, - right_numerators_opening: E, - left_denominators_opening: E, - right_denominators_opening: E, - r_layer: E, -) -> (E, E) -where - E: FieldElement, -{ - // This is the `numerators` layer polynomial `f(x_0) = numerators(x_0, rx_0, ..., rx_{l-1})`, - // where `rx_0, ..., rx_{l-1}` are the random variables that were sampled during the sumcheck - // round for this layer. - let numerators_univariate = - MultiLinearPoly::from_evaluations(vec![left_numerators_opening, right_numerators_opening]); - - // This is analogous to `numerators_univariate`, but for the `denominators` layer polynomial - let denominators_univariate = MultiLinearPoly::from_evaluations(vec![ - left_denominators_opening, - right_denominators_opening, - ]); - - ( - numerators_univariate.evaluate(&[r_layer]), - denominators_univariate.evaluate(&[r_layer]), - ) + pub claimed_evaluations_per_circuit: Vec<(E, E)>, } /// Builds the auxiliary trace column for the univariate sum-check argument. @@ -380,35 +270,69 @@ where /// product i.e., equation (12) in [1]. This oracle is refered to throughout the codebase as /// the s-column. /// -/// The following function's purpose is two build the column in point 2 given the one in point 1. +/// The following function's purpose is to build the column in point 2 given the one in point 1. /// /// [1]: https://eprint.iacr.org/2023/1284 pub fn build_s_column( - main_trace: &impl Trace, + trace: &impl Trace, gkr_data: &GkrData, evaluator: &impl LogUpGkrEvaluator, lagrange_kernel_col: &[E], ) -> Vec { let c = gkr_data.compute_batched_claim(); - let main_segment = main_trace.main_segment(); - let mean = c / E::from(E::BaseField::from(main_segment.num_rows() as u32)); + let num_oracles = evaluator.get_oracles().len(); - let mut result = Vec::with_capacity(main_segment.num_rows()); - let mut last_value = E::ZERO; - result.push(last_value); + let main_segment = trace.main_segment(); + let num_cols = main_segment.num_cols(); + let num_rows = main_segment.num_rows(); + let mean = c / E::from(E::BaseField::from(num_rows as u32)); - let mut query = vec![E::BaseField::ZERO; evaluator.get_oracles().len()]; - let mut main_frame = EvaluationFrame::new(main_trace.main_segment().num_cols()); + #[cfg(not(feature = "concurrent"))] + let result = { + let mut result = Vec::with_capacity(num_rows); + let mut last_value = E::ZERO; + result.push(last_value); - for (i, item) in lagrange_kernel_col.iter().enumerate().take(main_segment.num_rows() - 1) { - main_trace.read_main_frame(i, &mut main_frame); + let mut query = vec![E::BaseField::ZERO; num_oracles]; + let mut main_frame = EvaluationFrame::new(num_cols); - evaluator.build_query(&main_frame, &mut query); - let cur_value = last_value - mean + gkr_data.compute_batched_query(&query) * *item; + for (i, item) in lagrange_kernel_col.iter().enumerate().take(num_rows - 1) { + trace.read_main_frame(i, &mut main_frame); - result.push(cur_value); - last_value = cur_value; - } + evaluator.build_query(&main_frame, &mut query); + let cur_value = last_value - mean + gkr_data.compute_batched_query(&query) * *item; + + result.push(cur_value); + last_value = cur_value; + } + + result + }; + + #[cfg(feature = "concurrent")] + let result = { + let mut deltas = unsafe { uninit_vector(num_rows) }; + deltas[0] = E::ZERO; + let batch_size = num_rows / rayon_num_threads().next_power_of_two(); + batch_iter_mut!(&mut deltas[1..], batch_size, |batch: &mut [E], batch_offset: usize| { + let mut query = vec![E::BaseField::ZERO; num_oracles]; + let mut main_frame = EvaluationFrame::::new(num_cols); + + for (i, v) in batch.iter_mut().enumerate() { + trace.read_main_frame(i + batch_offset, &mut main_frame); + + evaluator.build_query(&main_frame, &mut query); + *v = gkr_data.compute_batched_query(&query) * lagrange_kernel_col[i + batch_offset] + - mean; + } + }); + + // note that `deltas[0]` is set `0` and thus `deltas` satisfies the conditions for invoking + // the function + let mut cumulative_sum = deltas; + prefix_sum_parallel(&mut cumulative_sum, batch_size); + cumulative_sum + }; result } @@ -425,3 +349,53 @@ pub enum GkrProverError { #[error("failed to generate the random challenge")] FailedToGenerateChallenge, } + +// HELPER +// ================================================================================================= + +/// Computes the cumulative sum, also called prefix sum, of a vector of field elements using +/// parallelism, in place. +/// +/// The function divides the vector into non-overlapping segments and then computes an array of sums +/// for each segment. The function then applies the naive serial implementation to each segment and +/// uses the pre-computed sums in each segment in order to coordinate the results in the different +/// segments. +/// +/// The input vector is of the form `0 || values` where `values` are the values the cumulative sum +/// vector will be computed for, in place. +#[cfg(feature = "concurrent")] +fn prefix_sum_parallel(vector: &mut [E], batch_size: usize) { + let num_partitions = (vector.len() + batch_size - 1) / batch_size; + let mut sum_per_partition = vec![E::ZERO; num_partitions]; + + chunks!(vector, batch_size) + .zip(iter_mut!(sum_per_partition)) + .for_each(|(chunk, entry)| *entry = chunk.iter().fold(E::ZERO, |acc, term| acc + *term)); + + prefix_sum_truncate_right(&mut sum_per_partition); + + chunks_mut!(vector, batch_size) + .zip(iter!(sum_per_partition)) + .for_each(|(chunk, sum_so_far)| prefix_sum_truncate_left(chunk, *sum_so_far)); +} + +/// Computes the cumulative sum of a vector but omits the final cumulative sum. +#[cfg(feature = "concurrent")] +fn prefix_sum_truncate_right(values: &mut [E]) { + let mut sum = E::ZERO; + values.iter_mut().for_each(|v| { + let tmp = *v; + *v = sum; + sum += tmp; + }); +} + +/// Computes the cumulative sum of a vector but omits the initial cumulative sum, namely zero. +#[cfg(feature = "concurrent")] +fn prefix_sum_truncate_left(values: &mut [E], sum: E) { + let mut sum = sum; + values.iter_mut().for_each(|v| { + sum += *v; + *v = sum; + }); +} diff --git a/prover/src/logup_gkr/prover.rs b/prover/src/logup_gkr/prover.rs index f258d0845..c6cd52552 100644 --- a/prover/src/logup_gkr/prover.rs +++ b/prover/src/logup_gkr/prover.rs @@ -4,12 +4,16 @@ use air::{LogUpGkrEvaluator, LogUpGkrOracle, PeriodicTable}; use crypto::{ElementHasher, RandomCoin}; use math::FieldElement; use sumcheck::{ - sum_check_prove_higher_degree, sumcheck_prove_plain, BeforeFinalLayerProof, CircuitOutput, - EqFunction, FinalLayerProof, GkrCircuitProof, MultiLinearPoly, SumCheckProof, + sum_check_prove_higher_degree, sumcheck_prove_plain_batched, + sumcheck_prove_plain_batched_serial, BeforeFinalLayerProof, CircuitOutput, EqFunction, + FinalLayerProof, GkrCircuitProof, MultiLinearPoly, SumCheckProof, }; use tracing::instrument; +#[cfg(feature = "concurrent")] +use utils::rayon::prelude::*; +use utils::{iter, iter_mut, uninit_vector}; -use super::{reduce_layer_claim, CircuitLayerPolys, EvaluatedCircuit, GkrClaim, GkrProverError}; +use super::{CircuitLayerPolys, EvaluatedCircuit, GkrClaim, GkrProverError}; use crate::{matrix::ColMatrix, Trace}; // PROVER @@ -68,19 +72,20 @@ pub fn prove_gkr( } // evaluate the GKR fractional sum circuit - let circuit = EvaluatedCircuit::new(main_trace, evaluator, &logup_randomness)?; + let circuits = EvaluatedCircuit::new(main_trace, evaluator, &logup_randomness)?; // include the circuit output as part of the final proof - let CircuitLayerPolys { numerators, denominators } = circuit.output_layer().clone(); + let output_layers = circuits.output_layers().clone(); // run the GKR prover for all layers except the input layer - let (before_final_layer_proofs, gkr_claim) = prove_intermediate_layers(circuit, public_coin)?; + let (before_final_layer_proofs, gkr_claim, tensored_circuit_batching_randomness) = + prove_intermediate_layers(circuits, public_coin)?; // build the MLEs of the relevant main trace columns let main_trace_mls = - build_mls_from_main_trace_segment(evaluator.get_oracles(), main_trace.main_segment())?; + build_mle_from_main_trace_segment(evaluator.get_oracles(), main_trace.main_segment())?; // build the periodic table representing periodic columns as multi-linear extensions - let periodic_table = evaluator.build_periodic_values(); + let periodic_table = evaluator.build_periodic_values(main_trace.main_segment().num_rows()); // run the GKR prover for the input layer let final_layer_proof = prove_input_layer( @@ -89,11 +94,23 @@ pub fn prove_gkr( main_trace_mls, periodic_table, gkr_claim, + &tensored_circuit_batching_randomness, public_coin, )?; + let mut numerators_all_circuits = vec![]; + let mut denominators_all_circuits = vec![]; + for output_layer in output_layers { + let CircuitLayerPolys { numerators, denominators } = output_layer; + numerators_all_circuits.push(numerators); + denominators_all_circuits.push(denominators); + } + Ok(GkrCircuitProof { - circuit_outputs: CircuitOutput { numerators, denominators }, + circuit_outputs: CircuitOutput { + numerators: numerators_all_circuits, + denominators: denominators_all_circuits, + }, before_final_layer_proofs, final_layer_proof, }) @@ -111,23 +128,37 @@ fn prove_input_layer< multi_linear_ext_polys: Vec>, periodic_table: PeriodicTable, claim: GkrClaim, + tensored_batching_randomness: &[E], transcript: &mut C, ) -> Result, GkrProverError> { // parse the [GkrClaim] resulting from the previous GKR layer - let GkrClaim { evaluation_point, claimed_evaluation } = claim; + let GkrClaim { + evaluation_point, + claimed_evaluations_per_circuit: claimed_evaluations, + } = claim; + + let mut all_claims_concatenated = Vec::with_capacity(claimed_evaluations.len()); + for claimed_evaluation in claimed_evaluations.iter() { + all_claims_concatenated.extend_from_slice(&[claimed_evaluation.0, claimed_evaluation.1]); + } + transcript.reseed(H::hash_elements(&all_claims_concatenated)); - transcript.reseed(H::hash_elements(&[claimed_evaluation.0, claimed_evaluation.1])); let r_batch = transcript.draw().map_err(|_| GkrProverError::FailedToGenerateChallenge)?; - let claim = claimed_evaluation.0 + claimed_evaluation.1 * r_batch; + let mut full_claim = E::ZERO; + for (circuit_idx, claimed_evaluation) in claimed_evaluations.iter().enumerate() { + let claim = claimed_evaluation.0 + claimed_evaluation.1 * r_batch; + full_claim += claim * tensored_batching_randomness[circuit_idx] + } let proof = sum_check_prove_higher_degree( evaluator, evaluation_point, - claim, + full_claim, r_batch, log_up_randomness, multi_linear_ext_polys, periodic_table, + tensored_batching_randomness, transcript, )?; @@ -137,34 +168,40 @@ fn prove_input_layer< /// Builds the multi-linear extension polynomials needed to run the final sum-check of GKR for /// LogUp-GKR. #[instrument(skip_all)] -fn build_mls_from_main_trace_segment( +fn build_mle_from_main_trace_segment( oracles: &[LogUpGkrOracle], main_trace: &ColMatrix<::BaseField>, ) -> Result>, GkrProverError> { - let mut mls = vec![]; + let mut mls = Vec::with_capacity(oracles.len()); for oracle in oracles { match oracle { LogUpGkrOracle::CurrentRow(index) => { let col = main_trace.get_column(*index); - let values: Vec = col.iter().map(|value| E::from(*value)).collect(); + let values: Vec = iter!(col).map(|value| E::from(*value)).collect(); let ml = MultiLinearPoly::from_evaluations(values); mls.push(ml) }, LogUpGkrOracle::NextRow(index) => { let col = main_trace.get_column(*index); - let mut values: Vec = col.iter().map(|value| E::from(*value)).collect(); - values.rotate_left(1); + + let mut values: Vec = unsafe { uninit_vector(col.len()) }; + values[col.len() - 1] = E::from(col[0]); + iter_mut!(&mut values[..col.len() - 1]) + .enumerate() + .for_each(|(i, value)| *value = E::from(col[i + 1])); let ml = MultiLinearPoly::from_evaluations(values); mls.push(ml) }, }; } + Ok(mls) } /// Proves all GKR layers except for input layer. #[instrument(skip_all)] +#[allow(clippy::type_complexity)] fn prove_intermediate_layers< E: FieldElement, C: RandomCoin, @@ -172,18 +209,35 @@ fn prove_intermediate_layers< >( circuit: EvaluatedCircuit, transcript: &mut C, -) -> Result<(BeforeFinalLayerProof, GkrClaim), GkrProverError> { +) -> Result<(BeforeFinalLayerProof, GkrClaim, Vec), GkrProverError> { // absorb the circuit output layer. This corresponds to sending the four values of the output // layer to the verifier. The verifier then replies with a challenge `r` in order to evaluate // `p` and `q` at `r` as multi-linears. - let CircuitLayerPolys { numerators, denominators } = circuit.output_layer(); - let mut evaluations = numerators.evaluations().to_vec(); - evaluations.extend_from_slice(denominators.evaluations()); - transcript.reseed(H::hash_elements(&evaluations)); + let output_layers = circuit.output_layers(); + + let mut total_evaluations = + Vec::with_capacity(output_layers[0].numerators.evaluations().len() * 2); + for output_layer in output_layers.iter() { + total_evaluations.extend_from_slice(output_layer.numerators.evaluations()); + total_evaluations.extend_from_slice(output_layer.denominators.evaluations()); + } + transcript.reseed(H::hash_elements(&total_evaluations)); // generate the challenge and reduce [p0, p1, q0, q1] to [pr, qr] let r = transcript.draw().map_err(|_| GkrProverError::FailedToGenerateChallenge)?; - let mut claimed_evaluation = circuit.evaluate_output_layer(r); + let mut claimed_evaluations = circuit.evaluate_output_layer(r); + let num_circuits = claimed_evaluations.len(); + let log_num_circuits = num_circuits.next_power_of_two().ilog2(); + + let mut circuit_batching_randomness: Vec = Vec::with_capacity(log_num_circuits as usize); + for _ in 0..log_num_circuits { + let batching_r = + transcript.draw().map_err(|_| GkrProverError::FailedToGenerateChallenge)?; + circuit_batching_randomness.push(batching_r); + } + + let tensored_circuit_batching_randomness = + EqFunction::new(circuit_batching_randomness.into()).evaluations(); let mut layer_proofs: Vec> = Vec::new(); let mut evaluation_point = vec![r]; @@ -195,44 +249,42 @@ fn prove_intermediate_layers< // loop over all inner layers in order to iteratively reduce a layer in terms of its successor // layer. Note that we don't include the input layer, since its predecessor layer will be // reduced in terms of the input layer separately in `prove_final_circuit_layer`. - for inner_layer in circuit.layers().into_iter().skip(1).rev().skip(1) { + for inner_layer in circuit.layers().into_iter().rev().skip(1) { // construct the Lagrange kernel evaluated at the previous GKR round randomness - let mut eq_mle = EqFunction::ml_at(evaluation_point.into()); - - let (numerators, denominators) = inner_layer.into_numerators_denominators(); + let mut eq_mle = EqFunction::ml_at(evaluation_point.clone().into()); // run the sumcheck protocol let proof = sum_check_prove_num_rounds_degree_3( - claimed_evaluation, - numerators, - denominators, + inner_layer, + &claimed_evaluations, + &evaluation_point, &mut eq_mle, + &tensored_circuit_batching_randomness, transcript, )?; - // sample a random challenge to reduce claims - transcript.reseed(H::hash_elements(&proof.openings_claim.openings)); + // generate the random challenge to reduce two claims into a single claim + let mut total_openings = Vec::with_capacity(proof.openings_claim.openings.len() * 4); + for opening_circuit_i in proof.openings_claim.openings.iter() { + total_openings.extend_from_slice(opening_circuit_i); + } + transcript.reseed(H::hash_elements(&total_openings)); let r_layer = transcript.draw().map_err(|_| GkrProverError::FailedToGenerateChallenge)?; - // reduce the claim - claimed_evaluation = { - let left_numerators_opening = proof.openings_claim.openings[0]; - let right_numerators_opening = proof.openings_claim.openings[1]; - let left_denominators_opening = proof.openings_claim.openings[2]; - let right_denominators_opening = proof.openings_claim.openings[3]; - - reduce_layer_claim( - left_numerators_opening, - right_numerators_opening, - left_denominators_opening, - right_denominators_opening, - r_layer, - ) - }; + // reduce the claims + for (j, claimed_opening) in proof.openings_claim.openings.iter().enumerate() { + let p0 = claimed_opening[0]; + let p1 = claimed_opening[1]; + let q0 = claimed_opening[2]; + let q1 = claimed_opening[3]; + + let reduced_claim = (p0 + r_layer * (p1 - p0), q0 + r_layer * (q1 - q0)); + claimed_evaluations[j] = reduced_claim; + } // collect the randomness used for the current layer - let mut ext = vec![r_layer]; - ext.extend_from_slice(&proof.openings_claim.eval_point); + let mut ext = proof.openings_claim.eval_point.clone(); + ext.push(r_layer); evaluation_point = ext; layer_proofs.push(proof); @@ -240,7 +292,11 @@ fn prove_intermediate_layers< Ok(( BeforeFinalLayerProof { proof: layer_proofs }, - GkrClaim { evaluation_point, claimed_evaluation }, + GkrClaim { + evaluation_point, + claimed_evaluations_per_circuit: claimed_evaluations, + }, + tensored_circuit_batching_randomness, )) } @@ -251,18 +307,47 @@ fn sum_check_prove_num_rounds_degree_3< C: RandomCoin, H: ElementHasher, >( - claim: (E, E), - p: MultiLinearPoly, - q: MultiLinearPoly, + inner_layers: Vec>, + claims: &[(E, E)], + evaluation_point: &[E], eq: &mut MultiLinearPoly, + tensored_batching_randomness: &[E], transcript: &mut C, ) -> Result, GkrProverError> { // generate challenge to batch two sumchecks - transcript.reseed(H::hash_elements(&[claim.0, claim.1])); - let r_batch = transcript.draw().map_err(|_| GkrProverError::FailedToGenerateChallenge)?; - let claim = claim.0 + claim.1 * r_batch; + let mut concatenated_claims = Vec::with_capacity(claims.len() * 2); + for claim in claims { + concatenated_claims.extend_from_slice(&[claim.0, claim.1]); + } + transcript.reseed(H::hash_elements(&concatenated_claims)); - let proof = sumcheck_prove_plain(claim, r_batch, p, q, eq, transcript)?; + let r_batch = transcript.draw().map_err(|_| GkrProverError::FailedToGenerateChallenge)?; + let mut batched_claims = vec![]; + for claim in claims { + let claim = claim.0 + claim.1 * r_batch; + batched_claims.push(claim) + } + let proof = if inner_layers[0].numerators.num_evaluations() >= 16 { + sumcheck_prove_plain_batched( + &batched_claims, + evaluation_point, + r_batch, + inner_layers, + eq, + tensored_batching_randomness, + transcript, + )? + } else { + sumcheck_prove_plain_batched_serial( + &batched_claims, + evaluation_point, + r_batch, + inner_layers, + eq, + tensored_batching_randomness, + transcript, + )? + }; Ok(proof) } diff --git a/prover/src/trace/mod.rs b/prover/src/trace/mod.rs index 08fb49a2a..4aeca1177 100644 --- a/prover/src/trace/mod.rs +++ b/prover/src/trace/mod.rs @@ -244,6 +244,9 @@ pub trait Trace: Sized + Sync { let v = trace_length.ilog2() as usize; let gkr_data = aux_rand_elements.gkr_data().expect("should not be None"); let r = gkr_data.lagrange_kernel_rand_elements(); + // TODO: avoid reverse() + let mut r = r.to_vec(); + r.reverse(); // Loop over every Lagrange kernel constraint for constraint_idx in 1..v + 1 { diff --git a/sumcheck/Cargo.toml b/sumcheck/Cargo.toml index 7db2e8058..4563890bc 100644 --- a/sumcheck/Cargo.toml +++ b/sumcheck/Cargo.toml @@ -41,7 +41,7 @@ utils = { version = "0.9", path = "../utils/core", package = "winter-utils", def rayon = { version = "1.8", optional = true } smallvec = { version = "1.13", default-features = false } thiserror = { version = "1.0", git = "https://github.com/bitwalker/thiserror", branch = "no-std", default-features = false } - +libc-print = "0.1.23" [dev-dependencies] criterion = "0.5" rand-utils = { version = "0.9", path = "../utils/rand", package = "winter-rand-utils" } \ No newline at end of file diff --git a/sumcheck/benches/sum_check_high_degree.rs b/sumcheck/benches/sum_check_high_degree.rs index 483890579..d6bb8815e 100644 --- a/sumcheck/benches/sum_check_high_degree.rs +++ b/sumcheck/benches/sum_check_high_degree.rs @@ -44,6 +44,8 @@ fn sum_check_high_degree(c: &mut Criterion) { )| { let mls = vec![ml0, ml1, ml2, ml3, ml4]; let mut transcript = transcript; + let tensored_batching_randmoness = + rand_vector(evaluator.get_num_fractions().ilog2() as usize); sum_check_prove_higher_degree( &evaluator, @@ -53,6 +55,7 @@ fn sum_check_high_degree(c: &mut Criterion) { logup_randomness, mls, periodic_table, + &tensored_batching_randmoness, &mut transcript, ) }, diff --git a/sumcheck/benches/sum_check_plain.rs b/sumcheck/benches/sum_check_plain.rs index 14fd859ce..c031fedd2 100644 --- a/sumcheck/benches/sum_check_plain.rs +++ b/sumcheck/benches/sum_check_plain.rs @@ -11,7 +11,9 @@ use math::{fields::f64::BaseElement, FieldElement}; use rand_utils::{rand_value, rand_vector}; #[cfg(feature = "concurrent")] pub use rayon::prelude::*; -use winter_sumcheck::{sumcheck_prove_plain, EqFunction, MultiLinearPoly}; +use winter_sumcheck::{ + sumcheck_prove_plain_batched, CircuitLayerPolys, EqFunction, MultiLinearPoly, +}; const LOG_POLY_SIZE: [usize; 2] = [18, 20]; @@ -26,13 +28,31 @@ fn sum_check_plain(c: &mut Criterion) { || { let transcript = DefaultRandomCoin::>::new(&[BaseElement::ZERO; 4]); - (setup_sum_check::(log_poly_size), transcript) + (setup_sum_check::(log_poly_size, 4), transcript) }, - |((claim, r_batch, p, q, eq), transcript)| { + |( + ( + claims, + evaluation_point, + r_batch, + inner_layers, + tensored_batching_randomness, + eq, + ), + transcript, + )| { let mut eq = eq; let mut transcript = transcript; - sumcheck_prove_plain(claim, r_batch, p, q, &mut eq, &mut transcript) + sumcheck_prove_plain_batched( + &claims, + &evaluation_point, + r_batch, + inner_layers, + &mut eq, + &tensored_batching_randomness, + &mut transcript, + ) }, BatchSize::SmallInput, ) @@ -44,22 +64,39 @@ fn sum_check_plain(c: &mut Criterion) { #[allow(clippy::type_complexity)] fn setup_sum_check( log_size: usize, -) -> (E, E, MultiLinearPoly, MultiLinearPoly, MultiLinearPoly) { + num_fractions: usize, +) -> (Vec, Vec, E, Vec>, Vec, MultiLinearPoly) { let n = 1 << (log_size + 1); - let p: Vec = rand_vector(n); - let q: Vec = rand_vector(n); // this will not generate the correct claim with overwhelming probability but should be fine // for benchmarking - let rand_pt = rand_vector(log_size); + let evaluation_point = rand_vector(log_size); let r_batch: E = rand_value(); - let claim: E = rand_value(); + let claims: Vec = vec![rand_value(); num_fractions]; - let p = MultiLinearPoly::from_evaluations(p); - let q = MultiLinearPoly::from_evaluations(q); - let eq = MultiLinearPoly::from_evaluations(EqFunction::new(rand_pt.into()).evaluations()); + let mut inner_layers = Vec::with_capacity(num_fractions); + for _ in 0..num_fractions { + let p: Vec = rand_vector(n); + let q: Vec = rand_vector(n); + let p = MultiLinearPoly::from_evaluations(p); + let q = MultiLinearPoly::from_evaluations(q); + let inner_layer = CircuitLayerPolys::from_mle(p, q); + inner_layers.push(inner_layer) + } + let eq = MultiLinearPoly::from_evaluations( + EqFunction::new(evaluation_point.clone().into()).evaluations(), + ); + let tensored_batching_randomness = + EqFunction::new(rand_vector::(num_fractions).into()).evaluations(); - (claim, r_batch, p, q, eq) + ( + claims, + evaluation_point, + r_batch, + inner_layers, + tensored_batching_randomness, + eq, + ) } criterion_group!(group, sum_check_plain); diff --git a/sumcheck/src/lib.rs b/sumcheck/src/lib.rs index b7f670a9d..d086e46f7 100644 --- a/sumcheck/src/lib.rs +++ b/sumcheck/src/lib.rs @@ -23,7 +23,7 @@ mod verifier; pub use verifier::*; mod univariate; -pub use univariate::{CompressedUnivariatePoly, CompressedUnivariatePolyEvals}; +pub use univariate::CompressedUnivariatePoly; mod multilinear; pub use multilinear::{inner_product, EqFunction, MultiLinearPoly}; @@ -37,7 +37,7 @@ pub use multilinear::{inner_product, EqFunction, MultiLinearPoly}; #[derive(Clone, Debug)] pub struct FinalOpeningClaim { pub eval_point: Vec, - pub openings: Vec, + pub openings: Vec>, } impl Serializable for FinalOpeningClaim { @@ -229,8 +229,8 @@ where /// Holds the output layer of an [`EvaluatedCircuit`]. #[derive(Clone, Debug)] pub struct CircuitOutput { - pub numerators: MultiLinearPoly, - pub denominators: MultiLinearPoly, + pub numerators: Vec>, + pub denominators: Vec>, } impl Serializable for CircuitOutput @@ -250,8 +250,8 @@ where { fn read_from(source: &mut R) -> Result { Ok(Self { - numerators: MultiLinearPoly::read_from(source)?, - denominators: MultiLinearPoly::read_from(source)?, + numerators: Vec::>::read_from(source)?, + denominators: Vec::>::read_from(source)?, }) } } @@ -268,14 +268,22 @@ fn comb_func(p0: E, p1: E, q0: E, q1: E, eq: E, r_batch: E) -> /// The non-linear composition polynomial of the LogUp-GKR protocol specific to the input layer. pub fn evaluate_composition_poly( eq_at_mu: &[E], - numerators: &[E], - denominators: &[E], + numerators_zero: &[E], + denominators_zero: &[E], + numerators_one: &[E], + denominators_one: &[E], eq_eval: E, r_sum_check: E, ) -> E { - numerators - .chunks(2) - .zip(denominators.chunks(2).zip(eq_at_mu.iter())) - .map(|(p, (q, eq_w))| *eq_w * comb_func(p[0], p[1], q[0], q[1], eq_eval, r_sum_check)) + numerators_zero + .iter() + .zip( + numerators_one + .iter() + .zip(denominators_zero.iter().zip(denominators_one.iter().zip(eq_at_mu.iter()))), + ) + .map(|(p0, (p1, (q0, (q1, eq_w))))| { + *eq_w * comb_func(*p0, *p1, *q0, *q1, eq_eval, r_sum_check) + }) .fold(E::ZERO, |acc, x| acc + x) } diff --git a/sumcheck/src/multilinear.rs b/sumcheck/src/multilinear.rs index df6177914..3d8c4a700 100644 --- a/sumcheck/src/multilinear.rs +++ b/sumcheck/src/multilinear.rs @@ -73,63 +73,19 @@ impl MultiLinearPoly { #[inline(always)] pub fn bind_least_significant_variable(&mut self, round_challenge: E) { let num_evals = self.evaluations.len() >> 1; - #[cfg(not(feature = "concurrent"))] - { - for i in 0..num_evals { - // SAFETY: This loops over [0, evaluations.len()/2). The largest value for `i` is - // `(evaluations.len() / 2) - 1`. Hence, the largest value for `(i<<1)` is - // `evaluations.len() - 2`, and largest value for `(i<<1) + 1` is `evaluations.len() - 1`. - let evaluations_2i = unsafe { *self.evaluations.get_unchecked(i << 1) }; - let evaluations_2i_plus_1 = - unsafe { *self.evaluations.get_unchecked((i << 1) + 1) }; - - self.evaluations[i] = - evaluations_2i + round_challenge * (evaluations_2i_plus_1 - evaluations_2i); - } - self.evaluations.truncate(num_evals); - } - #[cfg(feature = "concurrent")] - { - let mut result = unsafe { utils::uninit_vector(num_evals) }; - result.par_iter_mut().enumerate().for_each(|(i, ev)| { - // SAFETY: This loops over [0, evaluations.len()/2). The largest value for `i` is - // `(evaluations.len() / 2) - 1`. Hence, the largest value for `(i<<1)` is - // `evaluations.len() - 2`, and largest value for `(i<<1) + 1` is `evaluations.len() - 1`. - let evaluations_2i = unsafe { *self.evaluations.get_unchecked(i << 1) }; - let evaluations_2i_plus_1 = - unsafe { *self.evaluations.get_unchecked((i << 1) + 1) }; - - *ev = evaluations_2i + round_challenge * (evaluations_2i_plus_1 - evaluations_2i); - }); - self.evaluations = result - } - } + for i in 0..num_evals { + // SAFETY: This loops over [0, evaluations.len()/2). The largest value for `i` is + // `(evaluations.len() / 2) - 1`. Hence, the largest value for `i` is + // `(evaluations.len() / 2) - 2`, and largest value for `i + evaluations.len()/2` + // is `evaluations.len() - 1`. + let evaluations_i = unsafe { *self.evaluations.get_unchecked(i) }; + let evaluations_i_plus_num_evals = unsafe { *self.evaluations.get_unchecked(num_evals + i) }; - /// Given the multilinear polynomial $f(y_0, y_1, ..., y_{{\nu} - 1})$, returns two polynomials: - /// $f(0, y_1, ..., y_{{\nu} - 1})$ and $f(1, y_1, ..., y_{{\nu} - 1})$. - pub fn project_least_significant_variable(mut self) -> (Self, Self) { - let odds: Vec = self - .evaluations - .iter() - .enumerate() - .filter_map(|(idx, x)| if idx % 2 == 1 { Some(*x) } else { None }) - .collect(); - - // Builds the evens multilinear from the current `self.evaluations` buffer, which saves an - // allocation. - let evens = { - let evens_size = self.num_evaluations() / 2; - for write_idx in 0..evens_size { - let read_idx = write_idx * 2; - self.evaluations[write_idx] = self.evaluations[read_idx]; - } - self.evaluations.truncate(evens_size); - - self.evaluations - }; - - (Self::from_evaluations(evens), Self::from_evaluations(odds)) + self.evaluations[i] = + evaluations_i + round_challenge * (evaluations_i_plus_num_evals - evaluations_i); + } + self.evaluations.truncate(num_evals); } } @@ -284,7 +240,7 @@ fn compute_lagrange_basis_evals_at(query: &[E]) -> Vec { evals[0] = E::ONE; #[cfg(not(feature = "concurrent"))] let evals = { - for r_i in query.iter() { + for r_i in query.iter().rev() { let (left_evals, right_evals) = evals.split_at_mut(size); left_evals.iter_mut().zip(right_evals.iter_mut()).for_each(|(left, right)| { let factor = *left; @@ -299,7 +255,7 @@ fn compute_lagrange_basis_evals_at(query: &[E]) -> Vec { #[cfg(feature = "concurrent")] let evals = { - for r_i in query.iter() { + for r_i in query.iter().rev() { let (left_evals, right_evals) = evals.split_at_mut(size); left_evals .par_iter_mut() @@ -380,14 +336,14 @@ fn test_eq_function() { let r1 = rand_value(); let eq_function = EqFunction::new(smallvec![r0, r1]); - let expected = vec![(one - r0) * (one - r1), r0 * (one - r1), (one - r0) * r1, r0 * r1]; + let expected = vec![(one - r0) * (one - r1), (one - r0) * r1, r0 * (one - r1), r0 * r1]; assert_eq!(expected, eq_function.evaluations()); // Lagrange kernel evaluation is correct let q0 = rand_value(); let q1 = rand_value(); - let tensored_query = vec![(one - q0) * (one - q1), q0 * (one - q1), (one - q0) * q1, q0 * q1]; + let tensored_query = vec![(one - q0) * (one - q1), (one - q0) * q1, q0 * (one - q1), q0 * q1]; let expected = inner_product(&tensored_query, &eq_function.evaluations()); diff --git a/sumcheck/src/prover/high_degree.rs b/sumcheck/src/prover/high_degree.rs index 47be290d7..7b525b617 100644 --- a/sumcheck/src/prover/high_degree.rs +++ b/sumcheck/src/prover/high_degree.rs @@ -11,10 +11,10 @@ use math::FieldElement; #[cfg(feature = "concurrent")] pub use rayon::prelude::*; -use super::SumCheckProverError; +use super::{compute_scaling_down_factors, to_coefficients, SumCheckProverError}; use crate::{ - evaluate_composition_poly, CompressedUnivariatePolyEvals, EqFunction, FinalOpeningClaim, - MultiLinearPoly, RoundProof, SumCheckProof, SumCheckRoundClaim, + evaluate_composition_poly, EqFunction, FinalOpeningClaim, MultiLinearPoly, RoundProof, + SumCheckProof, SumCheckRoundClaim, }; /// A sum-check prover for the input layer which can accommodate non-linear expressions in @@ -148,7 +148,102 @@ use crate::{ /// \right) /// $$ /// +/// +/// We now discuss a further optimization due to [2]. Suppose that we have a sum-check statment of +/// the following form: +/// +/// $$v_0=\sum_{x}Eq\left(\left(\alpha_0,\cdots,\alpha_{\nu - 1}\right);\left( x_0, \cdots, x_{\nu - 1}\right)\right) +/// C\left( x_0, \cdots, x_{\nu - 1} \right)$$ +/// +/// Then during round $i + 1$ of sum-check, the prover needs to send the following polynomial +/// +/// $$v_{i+1}(X)=\sum_{x}Eq\left(\left(\alpha_0,\cdots,\alpha_{i - 1},\alpha_i, \alpha_{i+1},\cdots\alpha_{\nu - 1} \right); +/// \left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// We can write $v_{i+1}(X)$ as: +/// +/// $$v_{i+1}(X)=Eq\left(\left(\alpha_0,\cdots,\alpha_{i - 1} \right);\left(r_0,\cdots,r_{i-1}\right)\right) +/// \cdot Eq\left(\alpha_i ;X\right)\sum_{x}Eq\left(\left(\alpha_{i+1},\cdots\alpha_{\nu - 1}\right);\left( x_{i+1}, \cdots x_{\nu - 1}\right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// This means that $v_{i+1}(X)$ is the product of: +/// +/// 1. A constant polynomial: $Eq\left( \left(\alpha_0, \cdots, \alpha_{i - 1} \right);\left( r_0, \cdots, r_{i-1} \right) \right)$ +/// 2. A linear polynomial: $Eq\left( \alpha_i ; X \right)$ +/// 3. A high degree polynomial: $\sum_{x} +/// Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right);\left( x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$ +/// +/// The advantage of the above decomposition is that the prover when computing $v_{i+1}(X)$ needs to sum over +/// +/// $$ +/// Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right);\left( x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right) +/// $$ +/// +/// instead of +/// +/// $$ +/// Eq\left( \left(\alpha_0, \cdots, \alpha_{i - 1}, \alpha_i, \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right) +/// $$ +/// +/// which has the advantage of being of degree $1$ less and hence requires less work on the part of the prover. +/// +/// Thus, the prover computes the following polynomial +/// +/// $$v_{i+1}^{'}(X) = \sum_{x} Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left( x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// and then scales it in order to get +/// +/// $$ +/// v_{i+1}(X) = v_{i+1}^{'}(X) Eq\left( \left(\alpha_0, \cdots, \alpha_{i - 1} \right); +/// \left( r_0, \cdots, r_{i-1} \right) \right) \cdot Eq\left( \alpha_i ; X \right) +/// $$ +/// +/// As the prover computes $v_{i+1}^{'}(X)$ in evaluation form and hence also $v_{i+1}(X)$, this +/// means that due to the degrees being off by $1$, the prover uses the linear factor in order to +/// obtain an additional evaluation point in order to be able to interpolate $v_{i+1}(X)$. +/// More precisely, we can get a root of $$v_{i+1}(X) = 0$$ by solving $$Eq\left( \alpha_i ; X \right) = 0$$ +/// The latter equation has as solution $$\mathsf{r} = \frac{1 - \alpha}{1 - 2\cdot\alpha}$$ +/// which is, except with negligible probability, an evaluation point not in the original +/// evaluation set and hence the prover is able to interpolate $v_{i+1}(X)$ and send it to +/// the verifier. +/// +/// Note that in order to avoid having to compute $\{Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left( x_{i+1}, \cdots x_{\nu - 1} \right) \right)\}$ from $\{Eq\left( \left( \alpha_{i}, \cdots \alpha_{\nu - 1} \right); +/// \left( x_{i}, \cdots x_{\nu - 1} \right) \right)\}$, or vice versa, we can write +/// +/// $$v_{i+1}^{'}(X) = \sum_{x} Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left( x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// as +/// +/// $$v_{i+1}^{'}(X) = \frac{1}{Eq\left( \left( \alpha_{0}, \cdots, \alpha_{i} \right); +/// \left(0, \cdots, 0\right) \right)} \sum_{x} +/// Eq\left( \left( \alpha_{0}, \cdots, \alpha_{i}, \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left(0, \cdots, 0, x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// Thus, $\{Eq\left( \left( \alpha_{0}, \cdots, \alpha_{i}, \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left(0, \cdots, 0, x_{i+1}, \cdots x_{\nu - 1} \right) \right)\}$ can be read from +/// $\{Eq\left( \left( \alpha_{0}, \cdots, \alpha_{\nu - 1} \right);\left(x_{0}, \cdots x_{\nu - 1} \right) \right)\}$ +/// directly, at the cost of the relation between $v_{i+1}^{'}(X)$ and $v_{i+1}(X)$ becoming +/// +/// $$ +/// v_{i+1}(X) = v_{i+1}^{'}(X) \frac{Eq\left( \left(\alpha_0, \cdots, \alpha_{i - 1} \right); +/// \left( r_0, \cdots, r_{i-1} \right) \right)}{Eq\left( \left( \alpha_{0}, \cdots, \alpha_{i} \right); +/// \left(0, \cdots, 0\right) \right)} \cdot Eq\left( \alpha_i ; X \right) +/// $$ +/// +/// /// [1]: https://eprint.iacr.org/2023/1284 +/// [2]: https://eprint.iacr.org/2024/108 #[allow(clippy::too_many_arguments)] pub fn sum_check_prove_higher_degree< E: FieldElement, @@ -161,42 +256,56 @@ pub fn sum_check_prove_higher_degree< log_up_randomness: Vec, mut mls: Vec>, mut periodic_table: PeriodicTable, + tensored_circuits_batching: &[E], coin: &mut impl RandomCoin, ) -> Result, SumCheckProverError> { - let num_rounds = mls[0].num_variables(); + let num_rounds = mls[0].num_variables() - 1; let mut round_proofs = vec![]; - // split the evaluation point into two points of dimension mu and nu, respectively - let mu = evaluator.get_num_fractions().trailing_zeros() - 1; - let (evaluation_point_mu, evaluation_point_nu) = evaluation_point.split_at(mu as usize); - let eq_mu = EqFunction::ml_at(evaluation_point_mu.into()).evaluations().to_vec(); - let mut eq_nu = EqFunction::ml_at(evaluation_point_nu.into()); - + let eq_mle = EqFunction::ml_at(evaluation_point.clone().into()); // setup first round claim let mut current_round_claim = SumCheckRoundClaim { eval_point: vec![], claim }; // run the first round of the protocol - let round_poly_evals = sumcheck_round( - &eq_mu, + let mut round_poly_evals = sumcheck_round( + tensored_circuits_batching, evaluator, - &eq_nu, + &eq_mle, &mls, &periodic_table, &log_up_randomness, r_sum_check, ); - let round_poly_coefs = round_poly_evals.to_poly(current_round_claim.claim); + + // this will hold `Eq((\alpha_0, \cdots, \alpha_{i - 1});(r_0, \cdots, r_{i-1}))` + let mut scaling_up_factor = E::ONE; + // this will hold `Eq((\alpha_{0}, \cdots, \alpha_{i}); (0, \cdots, 0))` for all `i` + let scaling_down_factors = compute_scaling_down_factors(&evaluation_point); + // this is `\alpha_i` above + let mut alpha_i = evaluation_point[0]; + let scaling_down_factor = scaling_down_factors[0]; + let round_poly_coefs = to_coefficients( + &mut round_poly_evals, + current_round_claim.claim, + alpha_i, + scaling_down_factor, + scaling_up_factor, + ); // reseed with the s_0 polynomial coin.reseed(H::hash_elements(&round_poly_coefs.0)); round_proofs.push(RoundProof { round_poly_coefs }); - for i in 1..num_rounds { // generate random challenge r_i for the i-th round let round_challenge = coin.draw().map_err(|_| SumCheckProverError::FailedToGenerateChallenge)?; + // update `scaling_up_factor` + alpha_i = evaluation_point[evaluation_point.len() + 1 - mls[0].num_variables()]; + scaling_up_factor *= + round_challenge * alpha_i + (E::ONE - round_challenge) * (E::ONE - alpha_i); + // compute the new reduced round claim let new_round_claim = reduce_claim(&round_proofs[i - 1], current_round_claim, round_challenge); @@ -204,17 +313,16 @@ pub fn sum_check_prove_higher_degree< // fold each multi-linear using the round challenge mls.iter_mut() .for_each(|ml| ml.bind_least_significant_variable(round_challenge)); - eq_nu.bind_least_significant_variable(round_challenge); // fold each periodic multi-linear using the round challenge periodic_table.bind_least_significant_variable(round_challenge); // run the i-th round of the protocol using the folded multi-linears for the new reduced // claim. This basically computes the s_i polynomial. - let round_poly_evals = sumcheck_round( - &eq_mu, + let mut round_poly_evals = sumcheck_round( + tensored_circuits_batching, evaluator, - &eq_nu, + &eq_mle, &mls, &periodic_table, &log_up_randomness, @@ -224,7 +332,14 @@ pub fn sum_check_prove_higher_degree< // update the claim current_round_claim = new_round_claim; - let round_poly_coefs = round_poly_evals.to_poly(current_round_claim.claim); + let alpha_i = evaluation_point[i]; + let round_poly_coefs = to_coefficients( + &mut round_poly_evals, + current_round_claim.claim, + alpha_i, + scaling_down_factors[i], + scaling_up_factor, + ); // reseed with the s_i polynomial coin.reseed(H::hash_elements(&round_poly_coefs.0)); @@ -239,15 +354,20 @@ pub fn sum_check_prove_higher_degree< // fold each multi-linear using the last random round challenge mls.iter_mut() .for_each(|ml| ml.bind_least_significant_variable(round_challenge)); - eq_nu.bind_least_significant_variable(round_challenge); + + // fold each periodic multi-linear using the round challenge + periodic_table.bind_least_significant_variable(round_challenge); let SumCheckRoundClaim { eval_point, claim: _claim } = reduce_claim(&round_proofs[num_rounds - 1], current_round_claim, round_challenge); - let openings = mls.iter_mut().map(|ml| ml.evaluations()[0]).collect(); + let openings: Vec = mls + .into_iter() + .flat_map(|ml| [ml.evaluations()[0], ml.evaluations()[1]]) + .collect(); Ok(SumCheckProof { - openings_claim: FinalOpeningClaim { eval_point, openings }, + openings_claim: FinalOpeningClaim { eval_point, openings: vec![openings] }, round_proofs, }) } @@ -293,111 +413,157 @@ pub fn sum_check_prove_higher_degree< /// the previous one using only additions. This is the purpose of `deltas`, to hold the increments /// added to each multi-linear to compute the evaluation at the next point, and `evals_x` to hold /// the current evaluation at $x$ in $\{2, ... , d_max\}$. +#[allow(clippy::too_many_arguments)] fn sumcheck_round( - eq_mu: &[E], + tensored_circuits_batching: &[E], evaluator: &impl LogUpGkrEvaluator::BaseField>, eq_ml: &MultiLinearPoly, mls: &[MultiLinearPoly], periodic_table: &PeriodicTable, log_up_randomness: &[E], r_sum_check: E, -) -> CompressedUnivariatePolyEvals { +) -> Vec { let num_mls = mls.len(); let num_periodic = periodic_table.num_columns(); let num_vars = mls[0].num_variables(); - let num_rounds = num_vars - 1; + let num_rounds = num_vars - 1 - 1; #[cfg(not(feature = "concurrent"))] let evaluations = { - let mut evals_one = vec![E::ZERO; num_mls]; - let mut evals_zero = vec![E::ZERO; num_mls]; - let mut evals_x = vec![E::ZERO; num_mls]; - - let mut evals_periodic_one = vec![E::ZERO; num_periodic]; - let mut evals_periodic_zero = vec![E::ZERO; num_periodic]; - let mut evals_periodic_x = vec![E::ZERO; num_periodic]; - let mut eq_x = E::ZERO; - - let mut deltas = vec![E::ZERO; num_mls]; - let mut deltas_periodic = vec![E::ZERO; num_periodic]; - let mut eq_delta = E::ZERO; - - let mut numerators = vec![E::ZERO; evaluator.get_num_fractions()]; - let mut denominators = vec![E::ZERO; evaluator.get_num_fractions()]; + let mut evals_one_zero = vec![E::ZERO; num_mls]; + let mut evals_one_one = vec![E::ZERO; num_mls]; + let mut evals_zero_zero = vec![E::ZERO; num_mls]; + let mut evals_zero_one = vec![E::ZERO; num_mls]; + + let mut evals_x_zero = vec![E::ZERO; num_mls]; + let mut evals_x_one = vec![E::ZERO; num_mls]; + + let mut evals_periodic_zero_zero = vec![E::ZERO; num_periodic]; + let mut evals_periodic_zero_one = vec![E::ZERO; num_periodic]; + let mut evals_periodic_one_zero = vec![E::ZERO; num_periodic]; + let mut evals_periodic_one_one = vec![E::ZERO; num_periodic]; + + let mut evals_periodic_x_zero = vec![E::ZERO; num_periodic]; + let mut evals_periodic_x_one = vec![E::ZERO; num_periodic]; + + let mut deltas_zero = vec![E::ZERO; num_mls]; + let mut deltas_one = vec![E::ZERO; num_mls]; + let mut deltas_periodic_zero = vec![E::ZERO; num_periodic]; + let mut deltas_periodic_one = vec![E::ZERO; num_periodic]; + + let mut numerators_zero = vec![E::ZERO; evaluator.get_num_fractions()]; + let mut denominators_zero = vec![E::ZERO; evaluator.get_num_fractions()]; + let mut numerators_one = vec![E::ZERO; evaluator.get_num_fractions()]; + let mut denominators_one = vec![E::ZERO; evaluator.get_num_fractions()]; (0..1 << num_rounds) .map(|i| { - let mut total_evals = vec![E::ZERO; evaluator.max_degree()]; - + let mut total_evals = vec![E::ZERO; evaluator.max_degree() - 1]; for (j, ml) in mls.iter().enumerate() { - evals_zero[j] = ml.evaluations()[2 * i]; - evals_one[j] = ml.evaluations()[2 * i + 1]; + evals_zero_zero[j] = ml.evaluations()[2 * i]; + evals_zero_one[j] = ml.evaluations()[2 * i + 1]; + evals_one_zero[j] = ml.evaluations()[2 * i + 2 * (1 << num_rounds)]; + evals_one_one[j] = ml.evaluations()[2 * i + 2 * (1 << num_rounds) + 1]; } - - let eq_at_zero = eq_ml.evaluations()[2 * i]; - let eq_at_one = eq_ml.evaluations()[2 * i + 1]; + let eq_at_zero = eq_ml.evaluations()[i]; // add evaluation of periodic columns - periodic_table.fill_periodic_values_at(2 * i, &mut evals_periodic_zero); - periodic_table.fill_periodic_values_at(2 * i + 1, &mut evals_periodic_one); + periodic_table.fill_periodic_values_at(2 * i, &mut evals_periodic_zero_zero); + periodic_table.fill_periodic_values_at(2 * i + 1, &mut evals_periodic_zero_one); + periodic_table.fill_periodic_values_at( + 2 * i + 2 * (1 << num_rounds), + &mut evals_periodic_one_zero, + ); + periodic_table.fill_periodic_values_at( + 2 * i + 2 * (1 << num_rounds) + 1, + &mut evals_periodic_one_one, + ); - // compute the evaluation at 1 + // compute the evaluation at 0 + evaluator.evaluate_query( + &evals_zero_zero, + &evals_periodic_zero_zero, + log_up_randomness, + &mut numerators_zero, + &mut denominators_zero, + ); evaluator.evaluate_query( - &evals_one, - &evals_periodic_one, + &evals_zero_one, + &evals_periodic_zero_one, log_up_randomness, - &mut numerators, - &mut denominators, + &mut numerators_one, + &mut denominators_one, ); total_evals[0] = evaluate_composition_poly( - eq_mu, - &numerators, - &denominators, - eq_at_one, + tensored_circuits_batching, + &numerators_zero, + &denominators_zero, + &numerators_one, + &denominators_one, + eq_at_zero, r_sum_check, ); - // compute the evaluations at 2, ..., d_max points + // compute the evaluations at 2, ..., d_max - 1 points for i in 0..num_mls { - deltas[i] = evals_one[i] - evals_zero[i]; - evals_x[i] = evals_one[i]; + deltas_zero[i] = evals_one_zero[i] - evals_zero_zero[i]; + evals_x_zero[i] = evals_one_zero[i]; + deltas_one[i] = evals_one_one[i] - evals_zero_one[i]; + evals_x_one[i] = evals_one_one[i]; } for i in 0..num_periodic { - deltas_periodic[i] = evals_periodic_one[i] - evals_periodic_zero[i]; - evals_periodic_x[i] = evals_periodic_one[i]; + deltas_periodic_zero[i] = + evals_periodic_one_zero[i] - evals_periodic_zero_zero[i]; + evals_periodic_x_zero[i] = evals_periodic_one_zero[i]; + deltas_periodic_one[i] = evals_periodic_one_one[i] - evals_periodic_zero_one[i]; + evals_periodic_x_one[i] = evals_periodic_one_one[i]; } - eq_delta = eq_at_one - eq_at_zero; - eq_x = eq_at_one; for e in total_evals.iter_mut().skip(1) { - evals_x.iter_mut().zip(deltas.iter()).for_each(|(evx, delta)| { + evals_x_zero.iter_mut().zip(deltas_zero.iter()).for_each(|(evx, delta)| { *evx += *delta; }); - evals_periodic_x.iter_mut().zip(deltas_periodic.iter()).for_each( + evals_periodic_x_zero.iter_mut().zip(deltas_periodic_zero.iter()).for_each( + |(evx, delta)| { + *evx += *delta; + }, + ); + evals_x_one.iter_mut().zip(deltas_one.iter()).for_each(|(evx, delta)| { + *evx += *delta; + }); + evals_periodic_x_one.iter_mut().zip(deltas_periodic_one.iter()).for_each( |(evx, delta)| { *evx += *delta; }, ); - eq_x += eq_delta; evaluator.evaluate_query( - &evals_x, - &evals_periodic_x, + &evals_x_zero, + &evals_periodic_x_zero, + log_up_randomness, + &mut numerators_zero, + &mut denominators_zero, + ); + evaluator.evaluate_query( + &evals_x_one, + &evals_periodic_x_one, log_up_randomness, - &mut numerators, - &mut denominators, + &mut numerators_one, + &mut denominators_one, ); *e = evaluate_composition_poly( - eq_mu, - &numerators, - &denominators, - eq_x, + tensored_circuits_batching, + &numerators_zero, + &denominators_zero, + &numerators_one, + &denominators_one, + eq_at_zero, r_sum_check, ); } total_evals }) - .fold(vec![E::ZERO; evaluator.max_degree()], |mut acc, poly_eval| { + .fold(vec![E::ZERO; evaluator.max_degree() - 1], |mut acc, poly_eval| { acc.iter_mut().zip(poly_eval.iter()).for_each(|(a, b)| { *a += *b; }); @@ -411,118 +577,185 @@ fn sumcheck_round( .fold( || { ( + vec![E::ZERO; num_mls], + vec![E::ZERO; num_mls], + vec![E::ZERO; num_mls], vec![E::ZERO; num_mls], vec![E::ZERO; num_mls], vec![E::ZERO; num_mls], vec![E::ZERO; num_periodic], vec![E::ZERO; num_periodic], vec![E::ZERO; num_periodic], - vec![E::ZERO; evaluator.max_degree()], - vec![E::ZERO; evaluator.get_num_fractions()], - vec![E::ZERO; evaluator.get_num_fractions()], + vec![E::ZERO; num_periodic], + vec![E::ZERO; num_periodic], + vec![E::ZERO; num_periodic], + vec![E::ZERO; num_mls], vec![E::ZERO; num_mls], vec![E::ZERO; num_periodic], + vec![E::ZERO; num_periodic], + vec![E::ZERO; evaluator.get_num_fractions()], + vec![E::ZERO; evaluator.get_num_fractions()], + vec![E::ZERO; evaluator.get_num_fractions()], + vec![E::ZERO; evaluator.get_num_fractions()], + vec![E::ZERO; evaluator.max_degree() - 1], ) }, |( - mut evals_zero, - mut evals_one, - mut evals_x, - mut evals_periodic_zero, - mut evals_periodic_one, - mut evals_periodic_x, + mut evals_one_zero, + mut evals_one_one, + mut evals_zero_zero, + mut evals_zero_one, + mut evals_x_zero, + mut evals_x_one, + mut evals_periodic_zero_zero, + mut evals_periodic_zero_one, + mut evals_periodic_one_zero, + mut evals_periodic_one_one, + mut evals_periodic_x_zero, + mut evals_periodic_x_one, + mut deltas_zero, + mut deltas_one, + mut deltas_periodic_zero, + mut deltas_periodic_one, + mut numerators_zero, + mut numerators_one, + mut denominators_zero, + mut denominators_one, mut poly_evals, - mut numerators, - mut denominators, - mut deltas, - mut deltas_periodic, ), i| { for (j, ml) in mls.iter().enumerate() { - evals_zero[j] = ml.evaluations()[2 * i]; - evals_one[j] = ml.evaluations()[2 * i + 1]; + evals_zero_zero[j] = ml.evaluations()[2 * i]; + evals_zero_one[j] = ml.evaluations()[2 * i + 1]; + evals_one_zero[j] = ml.evaluations()[2 * i + 2 * (1 << num_rounds)]; + evals_one_one[j] = ml.evaluations()[2 * i + 2 * (1 << num_rounds) + 1]; } - let eq_at_zero = eq_ml.evaluations()[2 * i]; - let eq_at_one = eq_ml.evaluations()[2 * i + 1]; + let eq_at_zero = eq_ml.evaluations()[i]; // add evaluation of periodic columns - periodic_table.fill_periodic_values_at(2 * i, &mut evals_periodic_zero); - periodic_table.fill_periodic_values_at(2 * i + 1, &mut evals_periodic_one); + periodic_table.fill_periodic_values_at(2 * i, &mut evals_periodic_zero_zero); + periodic_table.fill_periodic_values_at(2 * i + 1, &mut evals_periodic_zero_one); + periodic_table.fill_periodic_values_at( + 2 * i + 2 * (1 << num_rounds), + &mut evals_periodic_one_zero, + ); + periodic_table.fill_periodic_values_at( + 2 * i + 2 * (1 << num_rounds) + 1, + &mut evals_periodic_one_one, + ); - // compute the evaluation at 1 + // compute the evaluation at 0 + evaluator.evaluate_query( + &evals_zero_zero, + &evals_periodic_zero_zero, + log_up_randomness, + &mut numerators_zero, + &mut denominators_zero, + ); evaluator.evaluate_query( - &evals_one, - &evals_periodic_one, + &evals_zero_one, + &evals_periodic_zero_one, log_up_randomness, - &mut numerators, - &mut denominators, + &mut numerators_one, + &mut denominators_one, ); - poly_evals[0] = evaluate_composition_poly( - eq_mu, - &numerators, - &denominators, - eq_at_one, + poly_evals[0] += evaluate_composition_poly( + tensored_circuits_batching, + &numerators_zero, + &denominators_zero, + &numerators_one, + &denominators_one, + eq_at_zero, r_sum_check, ); - // compute the evaluations at 2, ..., d_max points + // compute the evaluations at 2, ..., d_max - 1 points for i in 0..num_mls { - deltas[i] = evals_one[i] - evals_zero[i]; - evals_x[i] = evals_one[i]; + deltas_zero[i] = evals_one_zero[i] - evals_zero_zero[i]; + evals_x_zero[i] = evals_one_zero[i]; + deltas_one[i] = evals_one_one[i] - evals_zero_one[i]; + evals_x_one[i] = evals_one_one[i]; } for i in 0..num_periodic { - deltas_periodic[i] = evals_periodic_one[i] - evals_periodic_zero[i]; - evals_periodic_x[i] = evals_periodic_one[i]; + deltas_periodic_zero[i] = + evals_periodic_one_zero[i] - evals_periodic_zero_zero[i]; + evals_periodic_x_zero[i] = evals_periodic_one_zero[i]; + deltas_periodic_one[i] = evals_periodic_one_one[i] - evals_periodic_zero_one[i]; + evals_periodic_x_one[i] = evals_periodic_one_one[i]; } - let eq_delta = eq_at_one - eq_at_zero; - let mut eq_x = eq_at_one; for e in poly_evals.iter_mut().skip(1) { - evals_x.iter_mut().zip(deltas.iter()).for_each(|(evx, delta)| { + evals_x_zero.iter_mut().zip(deltas_zero.iter()).for_each(|(evx, delta)| { *evx += *delta; }); - evals_periodic_x.iter_mut().zip(deltas_periodic.iter()).for_each( + evals_periodic_x_zero.iter_mut().zip(deltas_periodic_zero.iter()).for_each( + |(evx, delta)| { + *evx += *delta; + }, + ); + evals_x_one.iter_mut().zip(deltas_one.iter()).for_each(|(evx, delta)| { + *evx += *delta; + }); + evals_periodic_x_one.iter_mut().zip(deltas_periodic_one.iter()).for_each( |(evx, delta)| { *evx += *delta; }, ); - eq_x += eq_delta; evaluator.evaluate_query( - &evals_x, - &evals_periodic_x, + &evals_x_zero, + &evals_periodic_x_zero, log_up_randomness, - &mut numerators, - &mut denominators, + &mut numerators_zero, + &mut denominators_zero, ); - *e = evaluate_composition_poly( - eq_mu, - &numerators, - &denominators, - eq_x, + evaluator.evaluate_query( + &evals_x_one, + &evals_periodic_x_one, + log_up_randomness, + &mut numerators_one, + &mut denominators_one, + ); + *e += evaluate_composition_poly( + tensored_circuits_batching, + &numerators_zero, + &denominators_zero, + &numerators_one, + &denominators_one, + eq_at_zero, r_sum_check, ); } ( - evals_zero, - evals_one, - evals_x, - evals_periodic_zero, - evals_periodic_one, - evals_periodic_x, + evals_one_zero, + evals_one_one, + evals_zero_zero, + evals_zero_one, + evals_x_zero, + evals_x_one, + evals_periodic_zero_zero, + evals_periodic_zero_one, + evals_periodic_one_zero, + evals_periodic_one_one, + evals_periodic_x_zero, + evals_periodic_x_one, + deltas_zero, + deltas_one, + deltas_periodic_zero, + deltas_periodic_one, + numerators_zero, + numerators_one, + denominators_zero, + denominators_one, poly_evals, - numerators, - denominators, - deltas, - deltas_periodic, ) }, ) - .map(|(_, _, _, poly_evals, ..)| poly_evals) + .map(|(.., poly_evals)| poly_evals) .reduce( - || vec![E::ZERO; evaluator.max_degree()], + || vec![E::ZERO; evaluator.max_degree() - 1], |mut acc, poly_eval| { acc.iter_mut().zip(poly_eval.iter()).for_each(|(a, b)| { *a += *b; @@ -531,7 +764,7 @@ fn sumcheck_round( }, ); - CompressedUnivariatePolyEvals(evaluations.into()) + evaluations } /// Reduces an old claim to a new claim using the round challenge. diff --git a/sumcheck/src/prover/mod.rs b/sumcheck/src/prover/mod.rs index 13d35e551..98f18b589 100644 --- a/sumcheck/src/prover/mod.rs +++ b/sumcheck/src/prover/mod.rs @@ -4,10 +4,236 @@ // LICENSE file in the root directory of this source tree. mod high_degree; +use alloc::{fmt, vec::Vec}; +use core::{fmt::Formatter, ops::Add}; + pub use high_degree::sum_check_prove_higher_degree; +use crate::CompressedUnivariatePoly; + mod plain; -pub use plain::sumcheck_prove_plain; +use math::{batch_inversion, FieldElement}; +pub use plain::{sumcheck_prove_plain_batched, sumcheck_prove_plain_batched_serial}; mod error; pub use error::SumCheckProverError; +use utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; + +use crate::MultiLinearPoly; + +// CIRCUIT LAYER POLYS +// =============================================================================================== + +/// Holds a layer of an [`EvaluatedCircuit`] in a representation amenable to proving circuit +/// evaluation using GKR. +#[derive(Clone, Debug)] +pub struct CircuitLayerPolys { + pub numerators: MultiLinearPoly, + pub denominators: MultiLinearPoly, +} + +impl Serializable for CircuitLayerPolys +where + E: FieldElement, +{ + fn write_into(&self, target: &mut W) { + let Self { numerators, denominators } = self; + numerators.write_into(target); + denominators.write_into(target); + } +} + +impl Deserializable for CircuitLayerPolys +where + E: FieldElement, +{ + fn read_from(source: &mut R) -> Result { + Ok(Self { + numerators: MultiLinearPoly::read_from(source)?, + denominators: MultiLinearPoly::read_from(source)?, + }) + } +} + +// CIRCUIT LAYER POLYS +// =============================================================================================== + +impl CircuitLayerPolys +where + E: FieldElement, +{ + pub fn from_circuit_layer(layers: &[Vec>]) -> Vec { + let mut result = vec![]; + for layer in layers { + result.push(Self::from_wires(layer.clone())) + } + result + } + + pub fn from_wires(wires: Vec>) -> Self { + let mut numerators = Vec::new(); + let mut denominators = Vec::new(); + + for wire in wires { + numerators.push(wire.numerator); + denominators.push(wire.denominator); + } + + Self { + numerators: MultiLinearPoly::from_evaluations(numerators), + denominators: MultiLinearPoly::from_evaluations(denominators), + } + } + + pub fn from_mle(numerators: MultiLinearPoly, denominators: MultiLinearPoly) -> Self { + Self { numerators, denominators } + } +} + +// CIRCUIT LAYER +// =============================================================================================== + +/// Represents a layer in a [`EvaluatedCircuit`]. +/// +/// A layer is made up of a set of `n` wires, where `n` is a power of two. This is the natural +/// circuit representation of a layer, where each consecutive pair of wires are summed to yield a +/// wire in the subsequent layer of an [`EvaluatedCircuit`]. +/// +/// Note that a [`Layer`] needs to be first converted to a [`LayerPolys`] before the evaluation of +/// the layer can be proved using GKR. +#[derive(Debug)] +pub struct CircuitLayer { + wires: Vec>, +} + +impl CircuitLayer { + /// Creates a new [`Layer`] from a set of projective coordinates. + /// + /// Panics if the number of projective coordinates is not a power of two. + pub fn new(wires: Vec>) -> Self { + assert!(wires.len().is_power_of_two()); + + Self { wires } + } + + /// Returns the wires that make up this circuit layer. + pub fn wires(&self) -> &[CircuitWire] { + &self.wires + } + + /// Returns the number of wires in the layer. + pub fn num_wires(&self) -> usize { + self.wires.len() + } +} + +// CIRCUIT WIRE +// =============================================================================================== + +/// Represents a fraction `numerator / denominator` as a pair `(numerator, denominator)`. This is +/// the type for the gates' inputs in [`prover::EvaluatedCircuit`]. +/// +/// Hence, addition is defined in the natural way fractions are added together: `a/b + c/d = (ad + +/// bc) / bd`. +#[derive(Clone, Copy)] +pub struct CircuitWire { + numerator: E, + denominator: E, +} + +impl CircuitWire +where + E: FieldElement, +{ + /// Creates new projective coordinates from a numerator and a denominator. + pub fn new(numerator: E, denominator: E) -> Self { + assert_ne!(denominator, E::ZERO); + + Self { numerator, denominator } + } +} + +impl Add for CircuitWire +where + E: FieldElement, +{ + type Output = Self; + + fn add(self, other: Self) -> Self { + let numerator = self.numerator * other.denominator + other.numerator * self.denominator; + let denominator = self.denominator * other.denominator; + + Self::new(numerator, denominator) + } +} + +impl fmt::Debug for CircuitWire { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "{} / {}", self.numerator, self.denominator) + } +} + +// HELPER +// =============================================================================================== + +/// Takes the evaluation of the polynomial $v_{i+1}^{'}(X)$ defined by +/// +/// $$v_{i+1}^{'}(X) = \sum_{x} Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left( x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// and computes the interpolation of the $v_{i+1}(X)$ polynomial defined by +/// +/// $$ +/// v_{i+1}(X) = v_{i+1}^{'}(X) \frac{Eq\left( \left(\alpha_0, \cdots, \alpha_{i - 1} \right); +/// \left( r_0, \cdots, r_{i-1} \right) \right)}{Eq\left( \left( \alpha_{0}, \cdots, \alpha_{i} \right); +/// \left(0, \cdots, 0\right) \right)} \cdot Eq\left( \alpha_i ; X \right) +/// $$ +/// +/// The function returns a `CompressedUnivariatePoly` instead of the full list of coefficients. +fn to_coefficients( + round_poly_evals: &mut [E], + claim: E, + alpha: E, + scaling_down_factor: E, + scaling_up_factor: E, +) -> CompressedUnivariatePoly { + let a = scaling_down_factor; + round_poly_evals.iter_mut().for_each(|e| *e *= scaling_up_factor); + + let mut round_poly_evaluations = Vec::with_capacity(round_poly_evals.len() + 1); + round_poly_evaluations.push(round_poly_evals[0] * compute_weight(alpha, E::ZERO) * a); + round_poly_evaluations.push(claim - round_poly_evaluations[0]); + + for (x, eval) in round_poly_evals.iter().skip(1).enumerate() { + round_poly_evaluations.push(*eval * compute_weight(alpha, E::from(x as u32 + 2)) * a) + } + + let root = (E::ONE - alpha) / (E::ONE - alpha.double()); + + CompressedUnivariatePoly::interpolate_equidistant_points(&round_poly_evaluations, root) +} + +/// Computes +/// +/// $$ +/// Eq\left( \left( \alpha_{0}, \cdots, \alpha_{i} \right); +/// \left(0, \cdots, 0\right) \right) +/// $$ +/// +/// given $(\alpha_0, \cdots, \alpha_{\nu - 1})$ for all $i$ in $0, \cdots, \nu - 1$. +fn compute_scaling_down_factors(gkr_point: &[E]) -> Vec { + let cumulative_product: Vec = gkr_point + .iter() + .scan(E::ONE, |acc, &x| { + *acc *= E::ONE - x; + Some(*acc) + }) + .collect(); + batch_inversion(&cumulative_product) +} + +/// Computes $EQ(x; \alpha)$. +fn compute_weight(alpha: E, x: E) -> E { + x * alpha + (E::ONE - x) * (E::ONE - alpha) +} diff --git a/sumcheck/src/prover/plain.rs b/sumcheck/src/prover/plain.rs index e0092cf10..617736e79 100644 --- a/sumcheck/src/prover/plain.rs +++ b/sumcheck/src/prover/plain.rs @@ -3,17 +3,17 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. +use alloc::vec::Vec; + use crypto::{ElementHasher, RandomCoin}; use math::FieldElement; #[cfg(feature = "concurrent")] pub use rayon::prelude::*; -use smallvec::smallvec; -use super::SumCheckProverError; -use crate::{ - comb_func, CompressedUnivariatePolyEvals, FinalOpeningClaim, MultiLinearPoly, RoundProof, - SumCheckProof, +use super::{ + compute_scaling_down_factors, to_coefficients, CircuitLayerPolys, SumCheckProverError, }; +use crate::{comb_func, FinalOpeningClaim, MultiLinearPoly, RoundProof, SumCheckProof}; /// Sum-check prover for non-linear multivariate polynomial of the simple LogUp-GKR. /// @@ -46,142 +46,364 @@ use crate::{ /// /// Note that the degree of the non-linear composition polynomial is 3. /// +/// +/// We now discuss a further optimization due to [2]. Suppose that we have a sum-check statment of +/// the following form: +/// +/// $$v_0=\sum_{x}Eq\left(\left(\alpha_0,\cdots,\alpha_{\nu - 1}\right);\left( x_0, \cdots, x_{\nu - 1}\right)\right) +/// C\left( x_0, \cdots, x_{\nu - 1} \right)$$ +/// +/// Then during round $i + 1$ of sum-check, the prover needs to send the following polynomial +/// +/// $$v_{i+1}(X)=\sum_{x}Eq\left(\left(\alpha_0,\cdots,\alpha_{i - 1},\alpha_i, \alpha_{i+1},\cdots\alpha_{\nu - 1} \right); +/// \left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// We can write $v_{i+1}(X)$ as: +/// +/// $$v_{i+1}(X)=Eq\left(\left(\alpha_0,\cdots,\alpha_{i - 1} \right);\left(r_0,\cdots,r_{i-1}\right)\right) +/// \cdot Eq\left(\alpha_i ;X\right)\sum_{x}Eq\left(\left(\alpha_{i+1},\cdots\alpha_{\nu - 1}\right);\left( x_{i+1}, \cdots x_{\nu - 1}\right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// This means that $v_{i+1}(X)$ is the product of: +/// +/// 1. A constant polynomial: $Eq\left( \left(\alpha_0, \cdots, \alpha_{i - 1} \right);\left( r_0, \cdots, r_{i-1} \right) \right)$ +/// 2. A linear polynomial: $Eq\left( \alpha_i ; X \right)$ +/// 3. A high degree polynomial: $\sum_{x} +/// Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right);\left( x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$ +/// +/// The advantage of the above decomposition is that the prover when computing $v_{i+1}(X)$ needs to sum over +/// +/// $$ +/// Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right);\left( x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right) +/// $$ +/// +/// instead of +/// +/// $$ +/// Eq\left( \left(\alpha_0, \cdots, \alpha_{i - 1}, \alpha_i, \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right) +/// $$ +/// +/// which has the advantage of being of degree $1$ less and hence requires less work on the part of the prover. +/// +/// Thus, the prover computes the following polynomial +/// +/// $$v_{i+1}^{'}(X) = \sum_{x} Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left( x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// and then scales it in order to get +/// +/// $$ +/// v_{i+1}(X) = v_{i+1}^{'}(X) Eq\left( \left(\alpha_0, \cdots, \alpha_{i - 1} \right); +/// \left( r_0, \cdots, r_{i-1} \right) \right) \cdot Eq\left( \alpha_i ; X \right) +/// $$ +/// +/// As the prover computes $v_{i+1}^{'}(X)$ in evaluation form and hence also $v_{i+1}(X)$, this +/// means that due to the degrees being off by $1$, the prover uses the linear factor in order to +/// obtain an additional evaluation point in order to be able to interpolate $v_{i+1}(X)$. +/// More precisely, we can get a root of $$v_{i+1}(X) = 0$$ by solving $$Eq\left( \alpha_i ; X \right) = 0$$ +/// The latter equation has as solution $$\mathsf{r} = \frac{1 - \alpha}{1 - 2\cdot\alpha}$$ +/// which is, except with negligible probability, an evaluation point not in the original +/// evaluation set and hence the prover is able to interpolate $v_{i+1}(X)$ and send it to +/// the verifier. +/// +/// Note that in order to avoid having to compute $\{Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left( x_{i+1}, \cdots x_{\nu - 1} \right) \right)\}$ from $\{Eq\left( \left( \alpha_{i}, \cdots \alpha_{\nu - 1} \right); +/// \left( x_{i}, \cdots x_{\nu - 1} \right) \right)\}$, or vice versa, we can write +/// +/// $$v_{i+1}^{'}(X) = \sum_{x} Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left( x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// as +/// +/// $$v_{i+1}^{'}(X) = \frac{1}{Eq\left( \left( \alpha_{0}, \cdots, \alpha_{i} \right); +/// \left(0, \cdots, 0\right) \right)} \sum_{x} +/// Eq\left( \left( \alpha_{0}, \cdots, \alpha_{i}, \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left(0, \cdots, 0, x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// Thus, $\{Eq\left( \left( \alpha_{0}, \cdots, \alpha_{i}, \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left(0, \cdots, 0, x_{i+1}, \cdots x_{\nu - 1} \right) \right)\}$ can be read from +/// $\{Eq\left( \left( \alpha_{0}, \cdots, \alpha_{\nu - 1} \right);\left(x_{0}, \cdots x_{\nu - 1} \right) \right)\}$ +/// directly, at the cost of the relation between $v_{i+1}^{'}(X)$ and $v_{i+1}(X)$ becoming +/// +/// $$ +/// v_{i+1}(X) = v_{i+1}^{'}(X) \frac{Eq\left( \left(\alpha_0, \cdots, \alpha_{i - 1} \right); +/// \left( r_0, \cdots, r_{i-1} \right) \right)}{Eq\left( \left( \alpha_{0}, \cdots, \alpha_{i} \right); +/// \left(0, \cdots, 0\right) \right)} \cdot Eq\left( \alpha_i ; X \right) +/// $$ +/// +/// /// [1]: https://eprint.iacr.org/2023/1284 -#[allow(clippy::too_many_arguments)] -pub fn sumcheck_prove_plain>( - mut claim: E, +/// [2]: https://eprint.iacr.org/2024/108 +pub fn sumcheck_prove_plain_batched>( + claims: &[E], + gkr_point: &[E], r_batch: E, - p: MultiLinearPoly, - q: MultiLinearPoly, + mut inner_layers: Vec>, eq: &mut MultiLinearPoly, + tensored_batching_randomness: &[E], transcript: &mut impl RandomCoin, ) -> Result, SumCheckProverError> { let mut round_proofs = vec![]; let mut challenges = vec![]; - // construct the vector of multi-linear polynomials - let (mut p0, mut p1) = p.project_least_significant_variable(); - let (mut q0, mut q1) = q.project_least_significant_variable(); + let mut batched_claim_across_circuits = claims + .iter() + .zip(tensored_batching_randomness.iter()) + .fold(E::ZERO, |acc, (&claim_for_circuit_i, &randomness_for_circuit_i)| { + acc + claim_for_circuit_i * randomness_for_circuit_i + }); + + let scaling_down_factors = compute_scaling_down_factors(gkr_point); + let mut scaling_up_factor = E::ONE; + + let num_sum_check_rounds = inner_layers[0].numerators.num_variables() - 1; + for i in 0..num_sum_check_rounds { + let len = inner_layers[0].numerators.num_evaluations() / 4; + + #[cfg(feature = "concurrent")] + let (all_round_poly_eval_at_0, all_round_poly_eval_at_2) = inner_layers + .par_iter() + .zip(tensored_batching_randomness.par_iter()) + .fold( + || (E::ZERO, E::ZERO), + |(_acc_eval_0, _acc_eval_2), (inner_layer, batching_randomness)| { + let (round_poly_eval_at_0, round_poly_eval_at_2) = + (0..len).fold((E::ZERO, E::ZERO), |(a, b), k| { + let p0_i_0 = inner_layer.numerators[2 * k]; + let p0_i_1 = inner_layer.numerators[2 * k + 1]; + let p1_i_0 = inner_layer.numerators[2 * (k + len)]; + let p1_i_1 = inner_layer.numerators[2 * (k + len) + 1]; + let q0_i_0 = inner_layer.denominators[2 * k]; + let q0_i_1 = inner_layer.denominators[2 * k + 1]; + let q1_i_0 = inner_layer.denominators[2 * (k + len)]; + let q1_i_1 = inner_layer.denominators[2 * (k + len) + 1]; + let round_poly_eval_at_0 = + comb_func(p0_i_0, p0_i_1, q0_i_0, q0_i_1, eq[k], r_batch); + + let p0_delta = p1_i_0 - p0_i_0; + let p1_delta = p1_i_1 - p0_i_1; + let q0_delta = q1_i_0 - q0_i_0; + let q1_delta = q1_i_1 - q0_i_1; + + let p0_eval_at_2 = p1_i_0 + p0_delta; + let p1_eval_at_2 = p1_i_1 + p1_delta; + let q0_eval_at_2 = q1_i_0 + q0_delta; + let q1_eval_at_2 = q1_i_1 + q1_delta; + let round_poly_eval_at_2 = comb_func( + p0_eval_at_2, + p1_eval_at_2, + q0_eval_at_2, + q1_eval_at_2, + eq[k], + r_batch, + ); + + (round_poly_eval_at_0 + a, round_poly_eval_at_2 + b) + }); - for _ in 0..p0.num_variables() { - let len = p0.num_evaluations() / 2; + let tmp_round_poly_eval_at_1 = round_poly_eval_at_0 * *batching_randomness; + let tmp_round_poly_eval_at_2 = round_poly_eval_at_2 * *batching_randomness; + + (tmp_round_poly_eval_at_1, tmp_round_poly_eval_at_2) + }, + ) + .reduce(|| (E::ZERO, E::ZERO), |(a0, b0), (a1, b1)| (a0 + a1, b0 + b1)); #[cfg(not(feature = "concurrent"))] - let (round_poly_eval_at_1, round_poly_eval_at_2, round_poly_eval_at_3) = (0..len).fold( - (E::ZERO, E::ZERO, E::ZERO), - |(acc_point_1, acc_point_2, acc_point_3), i| { - let round_poly_eval_at_1 = comb_func( - p0[2 * i + 1], - p1[2 * i + 1], - q0[2 * i + 1], - q1[2 * i + 1], - eq[2 * i + 1], - r_batch, - ); - - let p0_delta = p0[2 * i + 1] - p0[2 * i]; - let p1_delta = p1[2 * i + 1] - p1[2 * i]; - let q0_delta = q0[2 * i + 1] - q0[2 * i]; - let q1_delta = q1[2 * i + 1] - q1[2 * i]; - let eq_delta = eq[2 * i + 1] - eq[2 * i]; - - let mut p0_eval_at_x = p0[2 * i + 1] + p0_delta; - let mut p1_eval_at_x = p1[2 * i + 1] + p1_delta; - let mut q0_eval_at_x = q0[2 * i + 1] + q0_delta; - let mut q1_eval_at_x = q1[2 * i + 1] + q1_delta; - let mut eq_evx = eq[2 * i + 1] + eq_delta; - let round_poly_eval_at_2 = comb_func( - p0_eval_at_x, - p1_eval_at_x, - q0_eval_at_x, - q1_eval_at_x, - eq_evx, - r_batch, - ); - - p0_eval_at_x += p0_delta; - p1_eval_at_x += p1_delta; - q0_eval_at_x += q0_delta; - q1_eval_at_x += q1_delta; - eq_evx += eq_delta; - let round_poly_eval_at_3 = comb_func( - p0_eval_at_x, - p1_eval_at_x, - q0_eval_at_x, - q1_eval_at_x, - eq_evx, - r_batch, - ); - - ( - round_poly_eval_at_1 + acc_point_1, - round_poly_eval_at_2 + acc_point_2, - round_poly_eval_at_3 + acc_point_3, - ) - }, + let (all_round_poly_eval_at_0, all_round_poly_eval_at_2) = + inner_layers.iter().zip(tensored_batching_randomness).fold( + (E::ZERO, E::ZERO), + |(eval_poly0, eval_poly2), (inner_layer, batching_randomness)| { + let (round_poly_eval_at_0, round_poly_eval_at_2) = + (0..len).fold((E::ZERO, E::ZERO), |(acc_point_0, acc_point_2), k| { + let p0_i_0 = inner_layer.numerators[2 * k]; + let p0_i_1 = inner_layer.numerators[2 * k + 1]; + let p1_i_0 = inner_layer.numerators[2 * (k + len)]; + let p1_i_1 = inner_layer.numerators[2 * (k + len) + 1]; + let q0_i_0 = inner_layer.denominators[2 * k]; + let q0_i_1 = inner_layer.denominators[2 * k + 1]; + let q1_i_0 = inner_layer.denominators[2 * (k + len)]; + let q1_i_1 = inner_layer.denominators[2 * (k + len) + 1]; + let round_poly_eval_at_0 = + comb_func(p0_i_0, p0_i_1, q0_i_0, q0_i_1, eq[k], r_batch); + + let p0_delta = p1_i_0 - p0_i_0; + let p1_delta = p1_i_1 - p0_i_1; + let q0_delta = q1_i_0 - q0_i_0; + let q1_delta = q1_i_1 - q0_i_1; + + let p0_eval_at_2 = p1_i_0 + p0_delta; + let p1_eval_at_2 = p1_i_1 + p1_delta; + let q0_eval_at_2 = q1_i_0 + q0_delta; + let q1_eval_at_2 = q1_i_1 + q1_delta; + let round_poly_eval_at_2 = comb_func( + p0_eval_at_2, + p1_eval_at_2, + q0_eval_at_2, + q1_eval_at_2, + eq[k], + r_batch, + ); + + (round_poly_eval_at_0 + acc_point_0, round_poly_eval_at_2 + acc_point_2) + }); + + ( + eval_poly0 + round_poly_eval_at_0 * *batching_randomness, + eval_poly2 + round_poly_eval_at_2 * *batching_randomness, + ) + }, + ); + + let alpha_i = gkr_point[i]; + let compressed_round_poly = to_coefficients( + &mut [all_round_poly_eval_at_0, all_round_poly_eval_at_2], + batched_claim_across_circuits, + alpha_i, + scaling_down_factors[i], + scaling_up_factor, ); + // reseed with the s_i polynomial + transcript.reseed(H::hash_elements(&compressed_round_poly.0)); + let round_proof = RoundProof { + round_poly_coefs: compressed_round_poly.clone(), + }; + + let round_challenge = + transcript.draw().map_err(|_| SumCheckProverError::FailedToGenerateChallenge)?; + + #[cfg(not(feature = "concurrent"))] + for inner_layer in inner_layers.iter_mut() { + // fold each multi-linear using the round challenge + inner_layer.numerators.bind_least_significant_variable(round_challenge); + inner_layer.denominators.bind_least_significant_variable(round_challenge); + } + #[cfg(feature = "concurrent")] - let (round_poly_eval_at_1, round_poly_eval_at_2, round_poly_eval_at_3) = (0..len) - .into_par_iter() - .fold( - || (E::ZERO, E::ZERO, E::ZERO), - |(a, b, c), i| { - let round_poly_eval_at_1 = comb_func( - p0[2 * i + 1], - p1[2 * i + 1], - q0[2 * i + 1], - q1[2 * i + 1], - eq[2 * i + 1], - r_batch, - ); + inner_layers.par_iter_mut().for_each(|inner_layer| { + // fold each multi-linear using the round challenge + inner_layer.numerators.bind_least_significant_variable(round_challenge); + inner_layer.denominators.bind_least_significant_variable(round_challenge); + }); + + // update the scaling up factor + scaling_up_factor *= + round_challenge * alpha_i + (E::ONE - round_challenge) * (E::ONE - alpha_i); + + // compute the new reduced round claim + batched_claim_across_circuits = compressed_round_poly + .evaluate_using_claim(&batched_claim_across_circuits, &round_challenge); + + round_proofs.push(round_proof); + challenges.push(round_challenge); + } + + let mut openings = Vec::with_capacity(inner_layers.len()); + for inner_layer in inner_layers.iter_mut() { + let p = inner_layer.numerators.evaluations(); + let q = inner_layer.denominators.evaluations(); + openings.push(vec![p[0], p[1], q[0], q[1]]) + } + + Ok(SumCheckProof { + openings_claim: FinalOpeningClaim { eval_point: challenges, openings }, + round_proofs, + }) +} + +#[allow(clippy::too_many_arguments)] +pub fn sumcheck_prove_plain_batched_serial< + E: FieldElement, + H: ElementHasher, +>( + claims: &[E], + gkr_point: &[E], + r_batch: E, + mut inner_layers: Vec>, + eq: &mut MultiLinearPoly, + tensored_batching_randomness: &[E], + transcript: &mut impl RandomCoin, +) -> Result, SumCheckProverError> { + let mut round_proofs = vec![]; + + let mut challenges = vec![]; + + let mut batched_claim_across_circuits = claims + .iter() + .zip(tensored_batching_randomness.iter()) + .fold(E::ZERO, |acc, (&claim_for_circuit_i, &randomness_for_circuit_i)| { + acc + claim_for_circuit_i * randomness_for_circuit_i + }); - let p0_delta = p0[2 * i + 1] - p0[2 * i]; - let p1_delta = p1[2 * i + 1] - p1[2 * i]; - let q0_delta = q0[2 * i + 1] - q0[2 * i]; - let q1_delta = q1[2 * i + 1] - q1[2 * i]; - let eq_delta = eq[2 * i + 1] - eq[2 * i]; - - let mut p0_eval_at_x = p0[2 * i + 1] + p0_delta; - let mut p1_eval_at_x = p1[2 * i + 1] + p1_delta; - let mut q0_eval_at_x = q0[2 * i + 1] + q0_delta; - let mut q1_eval_at_x = q1[2 * i + 1] + q1_delta; - let mut eq_evx = eq[2 * i + 1] + eq_delta; + let scaling_down_factors = compute_scaling_down_factors(gkr_point); + let mut scaling_up_factor = E::ONE; + + let num_sum_check_rounds = inner_layers[0].numerators.num_variables() - 1; + + for i in 0..num_sum_check_rounds { + let mut all_round_poly_eval_at_0 = E::ZERO; + let mut all_round_poly_eval_at_2 = E::ZERO; + let len = inner_layers[0].numerators.num_evaluations() / 4; + + for (inner_layer, batching_randomness) in + inner_layers.iter().zip(tensored_batching_randomness) + { + let (round_poly_eval_at_0, round_poly_eval_at_2) = + (0..len).fold((E::ZERO, E::ZERO), |(acc_point_0, acc_point_2), k| { + let p0_i_0 = inner_layer.numerators[2 * k]; + let p0_i_1 = inner_layer.numerators[2 * k + 1]; + let p1_i_0 = inner_layer.numerators[2 * (k + len)]; + let p1_i_1 = inner_layer.numerators[2 * (k + len) + 1]; + let q0_i_0 = inner_layer.denominators[2 * k]; + let q0_i_1 = inner_layer.denominators[2 * k + 1]; + let q1_i_0 = inner_layer.denominators[2 * (k + len)]; + let q1_i_1 = inner_layer.denominators[2 * (k + len) + 1]; + let round_poly_eval_at_0 = + comb_func(p0_i_0, p0_i_1, q0_i_0, q0_i_1, eq[k], r_batch); + + let p0_delta = p1_i_0 - p0_i_0; + let p1_delta = p1_i_1 - p0_i_1; + let q0_delta = q1_i_0 - q0_i_0; + let q1_delta = q1_i_1 - q0_i_1; + + let p0_eval_at_2 = p1_i_0 + p0_delta; + let p1_eval_at_2 = p1_i_1 + p1_delta; + let q0_eval_at_2 = q1_i_0 + q0_delta; + let q1_eval_at_2 = q1_i_1 + q1_delta; let round_poly_eval_at_2 = comb_func( - p0_eval_at_x, - p1_eval_at_x, - q0_eval_at_x, - q1_eval_at_x, - eq_evx, + p0_eval_at_2, + p1_eval_at_2, + q0_eval_at_2, + q1_eval_at_2, + eq[k], r_batch, ); - p0_eval_at_x += p0_delta; - p1_eval_at_x += p1_delta; - q0_eval_at_x += q0_delta; - q1_eval_at_x += q1_delta; - eq_evx += eq_delta; - let round_poly_eval_at_3 = comb_func( - p0_eval_at_x, - p1_eval_at_x, - q0_eval_at_x, - q1_eval_at_x, - eq_evx, - r_batch, - ); + (round_poly_eval_at_0 + acc_point_0, round_poly_eval_at_2 + acc_point_2) + }); - (round_poly_eval_at_1 + a, round_poly_eval_at_2 + b, round_poly_eval_at_3 + c) - }, - ) - .reduce( - || (E::ZERO, E::ZERO, E::ZERO), - |(a0, b0, c0), (a1, b1, c1)| (a0 + a1, b0 + b1, c0 + c1), - ); + all_round_poly_eval_at_0 += round_poly_eval_at_0 * *batching_randomness; + all_round_poly_eval_at_2 += round_poly_eval_at_2 * *batching_randomness; + } - let evals = smallvec![round_poly_eval_at_1, round_poly_eval_at_2, round_poly_eval_at_3]; - let compressed_round_poly_evals = CompressedUnivariatePolyEvals(evals); - let compressed_round_poly = compressed_round_poly_evals.to_poly(claim); + let alpha_i = gkr_point[i]; + let compressed_round_poly = to_coefficients( + &mut [all_round_poly_eval_at_0, all_round_poly_eval_at_2], + batched_claim_across_circuits, + alpha_i, + scaling_down_factors[i], + scaling_up_factor, + ); // reseed with the s_i polynomial transcript.reseed(H::hash_elements(&compressed_round_poly.0)); @@ -192,25 +414,32 @@ pub fn sumcheck_prove_plain CompressedUnivariatePoly { // evaluate polynom::eval(&complete_coefficients, *challenge) } -} - -impl Serializable for CompressedUnivariatePoly { - fn write_into(&self, target: &mut W) { - let vector: Vec = self.0.clone().into_vec(); - vector.write_into(target); - } -} -impl Deserializable for CompressedUnivariatePoly -where - E: FieldElement, -{ - fn read_from(source: &mut R) -> Result { - let vector: Vec = Vec::::read_from(source)?; - Ok(Self(vector.into())) - } -} + /// Given the evaluations of a polynomial over the set $0, 1, \cdots, d - 1$ and a `root` not in + /// the interpolation set, computes its coefficients. + pub fn interpolate_equidistant_points(ys: &[E], root: E) -> CompressedUnivariatePoly { + // we factor out the term `(x - r)` where `r` is the root + let quotient: Vec = (0..ys.len()).map(|i| E::from(i as u32) - root).collect(); + let quotient_inv = batch_inversion("ient); + let mut ys: Vec = ys.iter().zip(quotient_inv.iter()).map(|(&y, &q)| y * q).collect(); -/// The evaluations of a univariate polynomial of degree n at 0, 1, ..., n with the evaluation at 0 -/// omitted. -/// -/// This compressed representation is useful during the sum-check protocol as the full uncompressed -/// representation can be recovered from the compressed one and the current sum-check round claim. -#[derive(Clone, Debug)] -pub struct CompressedUnivariatePolyEvals(pub(crate) SmallVec<[E; MAX_POLY_SIZE]>); + // the zeroth coefficient can be recovered immediately + let c0 = ys.remove(0); -impl CompressedUnivariatePolyEvals { - /// Gives the coefficient representation of a polynomial represented in evaluation form. - /// - /// Since the evaluation at 0 is omitted, we need to use the round claim to recover - /// the evaluation at 0 using the identity $p(0) + p(1) = claim$. - /// Now, we have that for any polynomial $p(x) = c0 + c1 * x + ... + c_{n-1} * x^{n - 1}$: - /// - /// 1. $p(0) = c0$. - /// 2. $p(x) = c0 + x * q(x) where q(x) = c1 + ... + c_{n-1} * x^{n - 2}$. - /// - /// This means that we can compute the evaluations of q at 1, ..., n - 1 using the evaluations - /// of p and thus reduce by 1 the size of the interpolation problem. - /// Once the coefficient of q are recovered, the c0 coefficient is appended to these and this - /// is precisely the coefficient representation of the original polynomial q. - /// Note that the coefficient of the linear term is removed as this coefficient can be recovered - /// from the remaining coefficients, again, using the round claim using the relation - /// $2 * c0 + c1 + ... c_{n - 1} = claim$. - pub fn to_poly(&self, round_claim: E) -> CompressedUnivariatePoly { - // construct the vector of interpolation points 1, ..., n - let n_minus_1 = self.0.len(); + // build the interpolation set + let n_minus_1 = ys.len(); let points = (1..=n_minus_1 as u32).map(E::BaseField::from).collect::>(); // construct their inverses. These will be needed for computing the evaluations - // of the q polynomial as well as for doing the interpolation on q + // of the q polynomial as well as for doing the interpolation on q where q is + // defined as $p(x) = c0 + x * q(x) where q(x) = c1 + ... + c_{n-1} * x^{n - 2}$ let points_inv = batch_inversion(&points); - // compute the zeroth coefficient - let c0 = round_claim - self.0[0]; - // compute the evaluations of q - let q_evals: Vec = self - .0 + let q_evals: Vec = ys .iter() .enumerate() .map(|(i, evals)| (*evals - c0).mul_base(points_inv[i])) @@ -118,11 +82,34 @@ impl CompressedUnivariatePolyEvals { // append c0 to the coefficients of q to get the coefficients of p. The linear term // coefficient is removed as this can be recovered from the other coefficients using // the reduced claim. - let mut coefficients = SmallVec::with_capacity(self.0.len() + 1); + let mut coefficients = SmallVec::<[E; MAX_POLY_SIZE]>::with_capacity(ys.len() + 1); coefficients.push(c0); - coefficients.extend_from_slice(&q_coefs[1..]); + coefficients.extend_from_slice(&q_coefs[..]); + + // multiply back the factor `(x - r)` + let mut p_coefficients = polynom::mul(&coefficients, &[-root, E::ONE]); + + // remove the linear factor as it can be recovered from the `claim` and the other factors + p_coefficients.remove(1); + + CompressedUnivariatePoly(p_coefficients.into()) + } +} - CompressedUnivariatePoly(coefficients) +impl Serializable for CompressedUnivariatePoly { + fn write_into(&self, target: &mut W) { + let vector: Vec = self.0.clone().into_vec(); + vector.write_into(target); + } +} + +impl Deserializable for CompressedUnivariatePoly +where + E: FieldElement, +{ + fn read_from(source: &mut R) -> Result { + let vector: Vec = Vec::::read_from(source)?; + Ok(Self(vector.into())) } } @@ -259,22 +246,18 @@ fn test_poly_partial() { use math::fields::f64::BaseElement; let degree = 1000; - let mut points: Vec = vec![BaseElement::ZERO; degree]; - points - .iter_mut() - .enumerate() - .for_each(|(i, node)| *node = BaseElement::from(i as u32)); + // compute the claim let p: Vec = rand_utils::rand_vector(degree); - let evals = polynom::eval_many(&p, &points); - - let mut partial_evals = evals.clone(); - partial_evals.remove(0); - - let partial_poly = CompressedUnivariatePolyEvals(partial_evals.into()); + let evals = polynom::eval_many(&p, &[BaseElement::ZERO, BaseElement::ONE]); let claim = evals[0] + evals[1]; - let poly_coeff = partial_poly.to_poly(claim); + // build compressed polynomial + let mut poly_coeff = p.clone(); + poly_coeff.remove(1); + let poly_coeff = CompressedUnivariatePoly(poly_coeff.into()); + + // generate random challenge let r = rand_utils::rand_vector(1); assert_eq!(polynom::eval(&p, r[0]), poly_coeff.evaluate_using_claim(&claim, &r[0])) diff --git a/sumcheck/src/verifier/mod.rs b/sumcheck/src/verifier/mod.rs index 900be4c86..4f59ec67a 100644 --- a/sumcheck/src/verifier/mod.rs +++ b/sumcheck/src/verifier/mod.rs @@ -22,31 +22,50 @@ pub fn verify_sum_check_intermediate_layers< >( proof: &SumCheckProof, gkr_eval_point: &[E], - claim: (E, E), + claims: &[(E, E)], + tensored_circuit_batching_randomness: &[E], transcript: &mut impl RandomCoin, ) -> Result, SumCheckVerifierError> { // generate challenge to batch sum-checks - transcript.reseed(H::hash_elements(&[claim.0, claim.1])); + let mut concatenated_claims = Vec::with_capacity(claims.len() * 2); + for claim in claims { + concatenated_claims.extend_from_slice(&[claim.0, claim.1]); + } + transcript.reseed(H::hash_elements(&concatenated_claims)); + let r_batch: E = transcript .draw() .map_err(|_| SumCheckVerifierError::FailedToGenerateChallenge)?; - // compute the claim for the batched sum-check - let reduced_claim = claim.0 + claim.1 * r_batch; + let mut batched_claims = vec![]; + for claim in claims { + let claim = claim.0 + claim.1 * r_batch; + batched_claims.push(claim) + } + let mut full_claim = E::ZERO; + for (circuit_id, claim) in batched_claims.iter().enumerate() { + full_claim += *claim * tensored_circuit_batching_randomness[circuit_id] + } let SumCheckProof { openings_claim, round_proofs } = proof; - let final_round_claim = verify_rounds(reduced_claim, round_proofs, transcript)?; + let final_round_claim = verify_rounds(full_claim, round_proofs, transcript)?; assert_eq!(openings_claim.eval_point, final_round_claim.eval_point); - let p0 = openings_claim.openings[0]; - let p1 = openings_claim.openings[1]; - let q0 = openings_claim.openings[2]; - let q1 = openings_claim.openings[3]; + let mut eval_batched_circuits = E::ZERO; + let eq = EqFunction::new(gkr_eval_point.into()).evaluate(&openings_claim.eval_point.clone()); + for (circuit_idx, openings) in openings_claim.openings.iter().enumerate() { + let p0 = openings[0]; + let p1 = openings[1]; + let q0 = openings[2]; + let q1 = openings[3]; - let eq = EqFunction::new(gkr_eval_point.into()).evaluate(&openings_claim.eval_point); + eval_batched_circuits += comb_func(p0, p1, q0, q1, eq, r_batch) + * tensored_circuit_batching_randomness[circuit_idx] + } - if comb_func(p0, p1, q0, q1, eq, r_batch) != final_round_claim.claim { + if eval_batched_circuits != final_round_claim.claim { + assert_eq!(1, 0); return Err(SumCheckVerifierError::FinalEvaluationCheckFailed); } @@ -63,53 +82,89 @@ pub fn verify_sum_check_input_layer, log_up_randomness: Vec, gkr_eval_point: &[E], - claim: (E, E), + claim: Vec<(E, E)>, + tensored_circuit_batching_randomness: &[E], transcript: &mut impl RandomCoin, ) -> Result, SumCheckVerifierError> { + let mut all_claims_concatenated = Vec::with_capacity(claim.len()); + for claimed_evaluation in claim.iter() { + all_claims_concatenated.extend_from_slice(&[claimed_evaluation.0, claimed_evaluation.1]); + } + transcript.reseed(H::hash_elements(&all_claims_concatenated)); + // generate challenge to batch sum-checks - transcript.reseed(H::hash_elements(&[claim.0, claim.1])); let r_batch: E = transcript .draw() .map_err(|_| SumCheckVerifierError::FailedToGenerateChallenge)?; - // compute the claim for the batched sum-check - let reduced_claim = claim.0 + claim.1 * r_batch; + let mut batched_claims = vec![]; + for claimed_evaluation in claim.iter() { + let claim = claimed_evaluation.0 + claimed_evaluation.1 * r_batch; + batched_claims.push(claim) + } + + let mut full_claim = E::ZERO; + for (circuit_id, claim) in batched_claims.iter().enumerate() { + full_claim += *claim * tensored_circuit_batching_randomness[circuit_id] + } // verify the sum-check proof let SumCheckRoundClaim { eval_point, claim } = - verify_rounds(reduced_claim, &proof.0.round_proofs, transcript)?; + verify_rounds(full_claim, &proof.0.round_proofs, transcript)?; // execute the final evaluation check if proof.0.openings_claim.eval_point != eval_point { return Err(SumCheckVerifierError::WrongOpeningPoint); } - let mut numerators = vec![E::ZERO; evaluator.get_num_fractions()]; - let mut denominators = vec![E::ZERO; evaluator.get_num_fractions()]; + let mut numerators_zero = vec![E::ZERO; evaluator.get_num_fractions()]; + let mut numerators_one = vec![E::ZERO; evaluator.get_num_fractions()]; + let mut denominators_zero = vec![E::ZERO; evaluator.get_num_fractions()]; + let mut denominators_one = vec![E::ZERO; evaluator.get_num_fractions()]; - let periodic_columns = evaluator.build_periodic_values(); - let periodic_columns_evaluations = + let trace_len = 1 << eval_point.len(); + let periodic_columns = evaluator.build_periodic_values(trace_len); + let (periodic_columns_evaluations_zero, periodic_columns_evaluations_one) = evaluate_periodic_columns_at(periodic_columns, &proof.0.openings_claim.eval_point); + let mut at_zero = Vec::with_capacity(proof.0.openings_claim.openings[0].len()); + let mut at_one = Vec::with_capacity(proof.0.openings_claim.openings[0].len()); + for ml in proof.0.openings_claim.openings[0].chunks(2) { + at_zero.push(ml[0]); + at_one.push(ml[1]); + } + evaluator.evaluate_query( - &proof.0.openings_claim.openings, - &periodic_columns_evaluations, + &at_zero, + &periodic_columns_evaluations_zero, &log_up_randomness, - &mut numerators, - &mut denominators, + &mut numerators_zero, + &mut denominators_zero, + ); + evaluator.evaluate_query( + &at_one, + &periodic_columns_evaluations_one, + &log_up_randomness, + &mut numerators_one, + &mut denominators_one, ); - let mu = evaluator.get_num_fractions().trailing_zeros() - 1; - let (evaluation_point_mu, evaluation_point_nu) = gkr_eval_point.split_at(mu as usize); - - let eq_mu = EqFunction::new(evaluation_point_mu.into()).evaluations(); - let eq_nu = EqFunction::new(evaluation_point_nu.into()); + let eq_nu = EqFunction::new(gkr_eval_point.into()); let eq_nu_eval = eq_nu.evaluate(&proof.0.openings_claim.eval_point); - let expected_evaluation = - evaluate_composition_poly(&eq_mu, &numerators, &denominators, eq_nu_eval, r_batch); + + let expected_evaluation = evaluate_composition_poly( + tensored_circuit_batching_randomness, + &numerators_zero, + &denominators_zero, + &numerators_one, + &denominators_one, + eq_nu_eval, + r_batch, + ); if expected_evaluation != claim { + assert_eq!(1, 0); Err(SumCheckVerifierError::FinalEvaluationCheckFailed) } else { Ok(proof.0.openings_claim.clone()) @@ -161,15 +216,24 @@ pub enum SumCheckVerifierError { fn evaluate_periodic_columns_at( periodic_columns: PeriodicTable, eval_point: &[E], -) -> Vec { - let mut evaluations = vec![]; +) -> (Vec, Vec) { + let mut eval_point_zero = eval_point.to_vec(); + let mut eval_point_one = eval_point.to_vec(); + eval_point_zero.push(E::ZERO); + eval_point_one.push(E::ONE); + + let mut evaluations_zero = vec![]; + let mut evaluations_one = vec![]; for col in periodic_columns.table() { let ml = MultiLinearPoly::from_evaluations(col.to_vec()); let num_variables = ml.num_variables(); - let point = &eval_point[..num_variables]; + let point_zero = &eval_point_zero[eval_point_zero.len() - num_variables..]; + let point_one = &eval_point_one[eval_point_one.len() - num_variables..]; - let evaluation = ml.evaluate(point); - evaluations.push(evaluation) + let evaluation_zero = ml.evaluate(point_zero); + evaluations_zero.push(evaluation_zero); + let evaluation_one = ml.evaluate(point_one); + evaluations_one.push(evaluation_one) } - evaluations + (evaluations_zero, evaluations_one) } diff --git a/utils/core/src/iterators.rs b/utils/core/src/iterators.rs index f978acd40..cf8328ae9 100644 --- a/utils/core/src/iterators.rs +++ b/utils/core/src/iterators.rs @@ -133,3 +133,21 @@ macro_rules! chunks { result }}; } + +/// Returns either a regular or a parallel mutable iterator over at most `chunk_size` elements +/// depending on whether `concurrent` feature is enabled. +/// +/// When `concurrent` feature is enabled, creates a parallel iterator; otherwise, creates a +/// regular iterator. +#[macro_export] +macro_rules! chunks_mut { + ($e: expr, $chunk_size: expr) => {{ + #[cfg(feature = "concurrent")] + let result = $e.par_chunks_mut($chunk_size); + + #[cfg(not(feature = "concurrent"))] + let result = $e.chunks_mut($chunk_size); + + result + }}; +} diff --git a/verifier/Cargo.toml b/verifier/Cargo.toml index 0c07d493c..b60ab0347 100644 --- a/verifier/Cargo.toml +++ b/verifier/Cargo.toml @@ -27,6 +27,7 @@ math = { version = "0.9", path = "../math", package = "winter-math", default-fea sumcheck = { version = "0.1", path = "../sumcheck", package = "winter-sumcheck", default-features = false } thiserror = { version = "1.0", git = "https://github.com/bitwalker/thiserror", branch = "no-std", default-features = false } utils = { version = "0.9", path = "../utils/core", package = "winter-utils", default-features = false } +libc-print = "0.1.23" # Allow math in docs [package.metadata.docs.rs] diff --git a/verifier/src/logup_gkr/mod.rs b/verifier/src/logup_gkr/mod.rs index e317e0ab1..00ff09341 100644 --- a/verifier/src/logup_gkr/mod.rs +++ b/verifier/src/logup_gkr/mod.rs @@ -4,7 +4,7 @@ use air::{Air, LogUpGkrEvaluator}; use crypto::{ElementHasher, RandomCoin}; use math::FieldElement; use sumcheck::{ - verify_sum_check_input_layer, verify_sum_check_intermediate_layers, CircuitOutput, + verify_sum_check_input_layer, verify_sum_check_intermediate_layers, CircuitOutput, EqFunction, FinalOpeningClaim, GkrCircuitProof, SumCheckVerifierError, }; @@ -34,32 +34,66 @@ pub fn verify_gkr< } = proof; let CircuitOutput { numerators, denominators } = circuit_outputs; - let p0 = numerators.evaluations()[0]; - let p1 = numerators.evaluations()[1]; - let q0 = denominators.evaluations()[0]; - let q1 = denominators.evaluations()[1]; - - // make sure that both denominators are not equal to E::ZERO - if q0 == E::ZERO || q1 == E::ZERO { - return Err(VerifierError::ZeroOutputDenominator); - } - - // check that the output matches the expected `claim` let claim = evaluator.compute_claim(pub_inputs, &logup_randomness); - if (p0 * q1 + p1 * q0) / (q0 * q1) != claim { + let mut total_evaluations = Vec::with_capacity(numerators.len() * 4); + let mut num_acc = E::ZERO; + let mut den_acc = E::ONE; + for (nums, dens) in numerators.iter().zip(denominators.iter()) { + total_evaluations.extend_from_slice(nums.evaluations()); + total_evaluations.extend_from_slice(dens.evaluations()); + + let p0 = nums.evaluations()[0]; + let p1 = nums.evaluations()[1]; + let q0 = dens.evaluations()[0]; + let q1 = dens.evaluations()[1]; + + // make sure that both denominators are not equal to E::ZERO + if q0 == E::ZERO || q1 == E::ZERO { + return Err(VerifierError::ZeroOutputDenominator); + } + + let cur_num = p0 * q1 + p1 * q0; + let cur_den = q0 * q1; + + let new_num = num_acc * cur_den + den_acc * cur_num; + let new_den = den_acc * cur_den; + num_acc = new_num; + den_acc = new_den; + } + if num_acc / den_acc != claim { return Err(VerifierError::MismatchingCircuitOutput); } + transcript.reseed(H::hash_elements(&total_evaluations)); // generate the random challenge to reduce two claims into a single claim - let mut evaluations = numerators.evaluations().to_vec(); - evaluations.extend_from_slice(denominators.evaluations()); - transcript.reseed(H::hash_elements(&evaluations)); let r = transcript.draw().map_err(|_| VerifierError::FailedToGenerateChallenge)?; // reduce the claim - let p_r = p0 + r * (p1 - p0); - let q_r = q0 + r * (q1 - q0); - let mut reduced_claim = (p_r, q_r); + let mut reduced_claims = vec![]; + for (nums, dens) in numerators.iter().zip(denominators.iter()) { + let p0 = nums.evaluations()[0]; + let p1 = nums.evaluations()[1]; + let q0 = dens.evaluations()[0]; + let q1 = dens.evaluations()[1]; + // reduce the claim + let p_r = p0 + r * (p1 - p0); + let q_r = q0 + r * (q1 - q0); + let reduced_claim = (p_r, q_r); + reduced_claims.push(reduced_claim) + } + + let num_circuits = reduced_claims.len(); + let log_num_circuits = num_circuits.next_power_of_two().ilog2(); + + let mut circuit_batching_randomness: Vec = vec![]; + + for _ in 0..log_num_circuits { + let batching_r = transcript.draw().map_err(|_| VerifierError::FailedToGenerateChallenge)?; + circuit_batching_randomness.push(batching_r); + } + + let tensored_circuit_batching_randomness = + EqFunction::new(circuit_batching_randomness.into()).evaluations(); // verify all GKR layers but for the last one let num_layers = before_final_layer_proofs.proof.len(); @@ -68,24 +102,32 @@ pub fn verify_gkr< let FinalOpeningClaim { eval_point, openings } = verify_sum_check_intermediate_layers( &before_final_layer_proofs.proof[i], &evaluation_point, - reduced_claim, + &reduced_claims, + &tensored_circuit_batching_randomness, transcript, )?; // generate the random challenge to reduce two claims into a single claim - transcript.reseed(H::hash_elements(&openings)); + let mut total_openings = Vec::with_capacity(openings.len() * 4); + for opening_circuit_i in openings.iter() { + total_openings.extend_from_slice(opening_circuit_i); + } + transcript.reseed(H::hash_elements(&total_openings)); let r_layer = transcript.draw().map_err(|_| VerifierError::FailedToGenerateChallenge)?; - let p0 = openings[0]; - let p1 = openings[1]; - let q0 = openings[2]; - let q1 = openings[3]; - reduced_claim = (p0 + r_layer * (p1 - p0), q0 + r_layer * (q1 - q0)); + for (circuit_id, ops) in openings.iter().enumerate() { + let p0 = ops[0]; + let p1 = ops[1]; + let q0 = ops[2]; + let q1 = ops[3]; + + let reduced_claim = (p0 + r_layer * (p1 - p0), q0 + r_layer * (q1 - q0)); + reduced_claims[circuit_id] = reduced_claim; + } // collect the randomness used for the current layer - let rand_sumcheck = eval_point; - let mut ext = vec![r_layer]; - ext.extend_from_slice(&rand_sumcheck); + let mut ext = eval_point.clone(); + ext.push(r_layer); evaluation_point = ext; } @@ -96,7 +138,8 @@ pub fn verify_gkr< final_layer_proof, logup_randomness, &evaluation_point, - reduced_claim, + reduced_claims, + &tensored_circuit_batching_randomness, transcript, ) .map_err(VerifierError::FailedToVerifySumCheck) diff --git a/winterfell/src/lib.rs b/winterfell/src/lib.rs index 621796864..291e0f27e 100644 --- a/winterfell/src/lib.rs +++ b/winterfell/src/lib.rs @@ -590,7 +590,7 @@ #[cfg(test)] extern crate std; -pub use air::{AuxRandElements, LogUpGkrEvaluator}; +pub use air::{AuxRandElements, LogUpGkrEvaluator, LogUpGkrOracle}; pub use prover::{ crypto, iterators, math, matrix, Air, AirContext, Assertion, AuxTraceWithMetadata, BoundaryConstraint, BoundaryConstraintGroup, CompositionPolyTrace, diff --git a/winterfell/src/tests/logup_gkr_periodic.rs b/winterfell/src/tests/logup_gkr_periodic.rs index 849cbbd5d..9b64b24d5 100644 --- a/winterfell/src/tests/logup_gkr_periodic.rs +++ b/winterfell/src/tests/logup_gkr_periodic.rs @@ -23,7 +23,7 @@ use crate::{ #[test] fn test_logup_gkr_periodic() { let aux_trace_width = 1; - let trace = LogUpGkrPeriodic::new(2_usize.pow(12), aux_trace_width); + let trace = LogUpGkrPeriodic::new(2_usize.pow(13), aux_trace_width); let prover = LogUpGkrPeriodicProver::new(aux_trace_width); let proof = prover.prove(trace).unwrap(); diff --git a/winterfell/src/tests/logup_gkr_simple.rs b/winterfell/src/tests/logup_gkr_simple.rs index 6c814c948..63a51d77b 100644 --- a/winterfell/src/tests/logup_gkr_simple.rs +++ b/winterfell/src/tests/logup_gkr_simple.rs @@ -23,7 +23,7 @@ use crate::{ #[test] fn test_logup_gkr() { let aux_trace_width = 1; - let trace = LogUpGkrSimple::new(2_usize.pow(7), aux_trace_width); + let trace = LogUpGkrSimple::new(2_usize.pow(13), aux_trace_width); let prover = LogUpGkrSimpleProver::new(aux_trace_width); let proof = prover.prove(trace).unwrap(); @@ -55,8 +55,10 @@ impl LogUpGkrSimple { (0..trace_len).map(|idx| BaseElement::from(idx as u32)).collect(); let mut multiplicity: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); - multiplicity[0] = BaseElement::new(3 * trace_len as u64 - 3 * 4); - multiplicity[1] = BaseElement::new(3 * 4); + multiplicity[1] = BaseElement::new(3 * trace_len as u64 - 3 * 4 - 1); + multiplicity[2] = BaseElement::new(3 * 4); + + multiplicity[trace_len - 3] = BaseElement::ONE; let mut values_0: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); @@ -75,6 +77,7 @@ impl LogUpGkrSimple { for i in 0..4 { values_2[i + 4] = BaseElement::ONE; } + values_1[trace_len - 3] = BaseElement::new(trace_len as u64 - 4); Self { main_trace: ColMatrix::new(vec![table, multiplicity, values_0, values_1, values_2]), @@ -189,27 +192,12 @@ pub struct PlainLogUpGkrEval { impl PlainLogUpGkrEval { pub fn new() -> Self { let committed_0 = LogUpGkrOracle::CurrentRow(0); - let committed_1 = LogUpGkrOracle::CurrentRow(1); + let committed_1 = LogUpGkrOracle::NextRow(1); let committed_2 = LogUpGkrOracle::CurrentRow(2); let committed_3 = LogUpGkrOracle::CurrentRow(3); let committed_4 = LogUpGkrOracle::CurrentRow(4); - let committed_0_next_row = LogUpGkrOracle::NextRow(0); - let committed_1_next_row = LogUpGkrOracle::NextRow(1); - let committed_2_next_row = LogUpGkrOracle::NextRow(2); - let committed_3_next_row = LogUpGkrOracle::NextRow(3); - let committed_4_next_row = LogUpGkrOracle::NextRow(4); - let oracles = vec![ - committed_0, - committed_1, - committed_2, - committed_3, - committed_4, - committed_0_next_row, - committed_1_next_row, - committed_2_next_row, - committed_3_next_row, - committed_4_next_row, - ]; + + let oracles = vec![committed_0, committed_1, committed_2, committed_3, committed_4]; Self { oracles, _field: PhantomData } } } @@ -228,7 +216,7 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { } fn get_num_fractions(&self) -> usize { - 4 + 5 } fn max_degree(&self) -> usize { @@ -240,16 +228,10 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { E: FieldElement, { query[0] = frame.current()[0]; - query[1] = frame.current()[1]; + query[1] = frame.next()[1]; query[2] = frame.current()[2]; query[3] = frame.current()[3]; query[4] = frame.current()[4]; - - query[5] = frame.next()[0]; - query[6] = frame.next()[1]; - query[7] = frame.next()[2]; - query[8] = frame.next()[3]; - query[9] = frame.next()[4]; } fn evaluate_query( @@ -263,18 +245,17 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { F: FieldElement, E: FieldElement + ExtensionOf, { - assert_eq!(numerator.len(), 4); - assert_eq!(denominator.len(), 4); - assert_eq!(query.len(), 10); numerator[0] = E::from(query[1]); numerator[1] = E::ONE; numerator[2] = E::ONE; numerator[3] = E::ONE; + numerator[4] = E::ZERO; denominator[0] = rand_values[0] - E::from(query[0]); denominator[1] = -(rand_values[0] - E::from(query[2])); denominator[2] = -(rand_values[0] - E::from(query[3])); denominator[3] = -(rand_values[0] - E::from(query[4])); + denominator[4] = -(rand_values[0] - E::from(query[4])); } fn compute_claim(&self, _inputs: &Self::PublicInputs, _rand_values: &[E]) -> E @@ -296,7 +277,7 @@ impl LogUpGkrSimpleProver { fn new(aux_trace_width: usize) -> Self { Self { aux_trace_width, - options: ProofOptions::new(1, 8, 0, FieldExtension::Quadratic, 2, 1), + options: ProofOptions::new(1, 8, 0, FieldExtension::None, 2, 1), } } }