From c317f96be15d8b8e7db5fa9b6d8381e36146e9c7 Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Mon, 23 Sep 2024 11:09:06 +0200 Subject: [PATCH 01/44] wip: change to bitreverse representation of MLE. Periodic columns example still failing --- air/src/air/logup_gkr/lagrange/transition.rs | 8 ++- air/src/air/logup_gkr/mod.rs | 2 +- prover/src/logup_gkr/prover.rs | 4 +- prover/src/trace/mod.rs | 3 + sumcheck/src/multilinear.rs | 18 +++--- sumcheck/src/prover/high_degree.rs | 29 +++++---- sumcheck/src/prover/plain.rs | 64 ++++++++++---------- sumcheck/src/verifier/mod.rs | 4 +- verifier/src/logup_gkr/mod.rs | 5 +- 9 files changed, 74 insertions(+), 63 deletions(-) diff --git a/air/src/air/logup_gkr/lagrange/transition.rs b/air/src/air/logup_gkr/lagrange/transition.rs index 5f5b110e6..51654b2c2 100644 --- a/air/src/air/logup_gkr/lagrange/transition.rs +++ b/air/src/air/logup_gkr/lagrange/transition.rs @@ -62,6 +62,9 @@ impl LagrangeKernelTransitionConstraints { let c = lagrange_kernel_column_frame; let v = c.num_rows() - 1; let r = lagrange_kernel_rand_elements; + // TODO: avoid reverse() + let mut r = r.to_vec(); + r.reverse(); let k = constraint_idx + 1; (r[v - k] * c[0]) - ((E::ONE - r[v - k]) * c[v - k + 1]) @@ -129,7 +132,10 @@ impl LagrangeKernelTransitionConstraints { let c = lagrange_kernel_column_frame; let v = c.num_rows() - 1; - let r = lagrange_kernel_rand_elements; + let r = lagrange_kernel_rand_elements; + // TODO: avoid reverse() + let mut r = r.to_vec(); + r.reverse(); for k in 1..v + 1 { transition_evals[k - 1] = (r[v - k] * c[0]) - ((E::ONE - r[v - k]) * c[v - k + 1]); diff --git a/air/src/air/logup_gkr/mod.rs b/air/src/air/logup_gkr/mod.rs index d3e198912..e2719c626 100644 --- a/air/src/air/logup_gkr/mod.rs +++ b/air/src/air/logup_gkr/mod.rs @@ -296,7 +296,7 @@ where if col.len() > 1 { let num_evals = col.len() >> 1; for i in 0..num_evals { - col[i] = col[i << 1] + round_challenge * (col[(i << 1) + 1] - col[i << 1]); + col[i] = col[i] + round_challenge * (col[i + num_evals] - col[i]); } col.truncate(num_evals) } diff --git a/prover/src/logup_gkr/prover.rs b/prover/src/logup_gkr/prover.rs index f258d0845..e39289217 100644 --- a/prover/src/logup_gkr/prover.rs +++ b/prover/src/logup_gkr/prover.rs @@ -231,8 +231,8 @@ fn prove_intermediate_layers< }; // collect the randomness used for the current layer - let mut ext = vec![r_layer]; - ext.extend_from_slice(&proof.openings_claim.eval_point); + let mut ext = proof.openings_claim.eval_point.clone(); + ext.push(r_layer); evaluation_point = ext; layer_proofs.push(proof); diff --git a/prover/src/trace/mod.rs b/prover/src/trace/mod.rs index 08fb49a2a..4aeca1177 100644 --- a/prover/src/trace/mod.rs +++ b/prover/src/trace/mod.rs @@ -244,6 +244,9 @@ pub trait Trace: Sized + Sync { let v = trace_length.ilog2() as usize; let gkr_data = aux_rand_elements.gkr_data().expect("should not be None"); let r = gkr_data.lagrange_kernel_rand_elements(); + // TODO: avoid reverse() + let mut r = r.to_vec(); + r.reverse(); // Loop over every Lagrange kernel constraint for constraint_idx in 1..v + 1 { diff --git a/sumcheck/src/multilinear.rs b/sumcheck/src/multilinear.rs index df6177914..49baabf57 100644 --- a/sumcheck/src/multilinear.rs +++ b/sumcheck/src/multilinear.rs @@ -79,9 +79,9 @@ impl MultiLinearPoly { // SAFETY: This loops over [0, evaluations.len()/2). The largest value for `i` is // `(evaluations.len() / 2) - 1`. Hence, the largest value for `(i<<1)` is // `evaluations.len() - 2`, and largest value for `(i<<1) + 1` is `evaluations.len() - 1`. - let evaluations_2i = unsafe { *self.evaluations.get_unchecked(i << 1) }; + let evaluations_2i = unsafe { *self.evaluations.get_unchecked(i) }; let evaluations_2i_plus_1 = - unsafe { *self.evaluations.get_unchecked((i << 1) + 1) }; + unsafe { *self.evaluations.get_unchecked(num_evals + i) }; self.evaluations[i] = evaluations_2i + round_challenge * (evaluations_2i_plus_1 - evaluations_2i); @@ -96,9 +96,9 @@ impl MultiLinearPoly { // SAFETY: This loops over [0, evaluations.len()/2). The largest value for `i` is // `(evaluations.len() / 2) - 1`. Hence, the largest value for `(i<<1)` is // `evaluations.len() - 2`, and largest value for `(i<<1) + 1` is `evaluations.len() - 1`. - let evaluations_2i = unsafe { *self.evaluations.get_unchecked(i << 1) }; + let evaluations_2i = unsafe { *self.evaluations.get_unchecked(i) }; let evaluations_2i_plus_1 = - unsafe { *self.evaluations.get_unchecked((i << 1) + 1) }; + unsafe { *self.evaluations.get_unchecked(num_evals + i) }; *ev = evaluations_2i + round_challenge * (evaluations_2i_plus_1 - evaluations_2i); }); @@ -107,7 +107,7 @@ impl MultiLinearPoly { } /// Given the multilinear polynomial $f(y_0, y_1, ..., y_{{\nu} - 1})$, returns two polynomials: - /// $f(0, y_1, ..., y_{{\nu} - 1})$ and $f(1, y_1, ..., y_{{\nu} - 1})$. + /// $f(y_0, y_1, ..., y_{{\nu} - 2}, 0)$ and $f(y_0, y_1, ..., y_{{\nu} - 2}, 1)$. pub fn project_least_significant_variable(mut self) -> (Self, Self) { let odds: Vec = self .evaluations @@ -284,7 +284,7 @@ fn compute_lagrange_basis_evals_at(query: &[E]) -> Vec { evals[0] = E::ONE; #[cfg(not(feature = "concurrent"))] let evals = { - for r_i in query.iter() { + for r_i in query.iter().rev() { let (left_evals, right_evals) = evals.split_at_mut(size); left_evals.iter_mut().zip(right_evals.iter_mut()).for_each(|(left, right)| { let factor = *left; @@ -299,7 +299,7 @@ fn compute_lagrange_basis_evals_at(query: &[E]) -> Vec { #[cfg(feature = "concurrent")] let evals = { - for r_i in query.iter() { + for r_i in query.iter().rev() { let (left_evals, right_evals) = evals.split_at_mut(size); left_evals .par_iter_mut() @@ -380,14 +380,14 @@ fn test_eq_function() { let r1 = rand_value(); let eq_function = EqFunction::new(smallvec![r0, r1]); - let expected = vec![(one - r0) * (one - r1), r0 * (one - r1), (one - r0) * r1, r0 * r1]; + let expected = vec![(one - r0) * (one - r1), (one - r0) * r1, r0 * (one - r1), r0 * r1]; assert_eq!(expected, eq_function.evaluations()); // Lagrange kernel evaluation is correct let q0 = rand_value(); let q1 = rand_value(); - let tensored_query = vec![(one - q0) * (one - q1), q0 * (one - q1), (one - q0) * q1, q0 * q1]; + let tensored_query = vec![(one - q0) * (one - q1), (one - q0) * q1, q0 * (one - q1), q0 * q1]; let expected = inner_product(&tensored_query, &eq_function.evaluations()); diff --git a/sumcheck/src/prover/high_degree.rs b/sumcheck/src/prover/high_degree.rs index 47be290d7..a9546baec 100644 --- a/sumcheck/src/prover/high_degree.rs +++ b/sumcheck/src/prover/high_degree.rs @@ -169,7 +169,9 @@ pub fn sum_check_prove_higher_degree< // split the evaluation point into two points of dimension mu and nu, respectively let mu = evaluator.get_num_fractions().trailing_zeros() - 1; - let (evaluation_point_mu, evaluation_point_nu) = evaluation_point.split_at(mu as usize); + let (evaluation_point_nu, evaluation_point_mu) = + evaluation_point.split_at(evaluation_point.len() - mu as usize); + let eq_mu = EqFunction::ml_at(evaluation_point_mu.into()).evaluations().to_vec(); let mut eq_nu = EqFunction::ml_at(evaluation_point_nu.into()); @@ -329,16 +331,17 @@ fn sumcheck_round( let mut total_evals = vec![E::ZERO; evaluator.max_degree()]; for (j, ml) in mls.iter().enumerate() { - evals_zero[j] = ml.evaluations()[2 * i]; - evals_one[j] = ml.evaluations()[2 * i + 1]; + evals_zero[j] = ml.evaluations()[i]; + evals_one[j] = ml.evaluations()[i + (1 << num_rounds)]; } - let eq_at_zero = eq_ml.evaluations()[2 * i]; - let eq_at_one = eq_ml.evaluations()[2 * i + 1]; + let eq_at_zero = eq_ml.evaluations()[i]; + let eq_at_one = eq_ml.evaluations()[i + (1 << num_rounds)]; // add evaluation of periodic columns - periodic_table.fill_periodic_values_at(2 * i, &mut evals_periodic_zero); - periodic_table.fill_periodic_values_at(2 * i + 1, &mut evals_periodic_one); + periodic_table.fill_periodic_values_at(i, &mut evals_periodic_zero); + periodic_table + .fill_periodic_values_at(i + (1 << num_rounds), &mut evals_periodic_one); // compute the evaluation at 1 evaluator.evaluate_query( @@ -439,16 +442,16 @@ fn sumcheck_round( ), i| { for (j, ml) in mls.iter().enumerate() { - evals_zero[j] = ml.evaluations()[2 * i]; - evals_one[j] = ml.evaluations()[2 * i + 1]; + evals_zero[j] = ml.evaluations()[i]; + evals_one[j] = ml.evaluations()[i + (1 << num_rounds)]; } - let eq_at_zero = eq_ml.evaluations()[2 * i]; - let eq_at_one = eq_ml.evaluations()[2 * i + 1]; + let eq_at_zero = eq_ml.evaluations()[i]; + let eq_at_one = eq_ml.evaluations()[i + (1 << num_rounds)]; // add evaluation of periodic columns - periodic_table.fill_periodic_values_at(2 * i, &mut evals_periodic_zero); - periodic_table.fill_periodic_values_at(2 * i + 1, &mut evals_periodic_one); + periodic_table.fill_periodic_values_at(i, &mut evals_periodic_zero); + periodic_table.fill_periodic_values_at(i + (1 << num_rounds), &mut evals_periodic_one); // compute the evaluation at 1 evaluator.evaluate_query( diff --git a/sumcheck/src/prover/plain.rs b/sumcheck/src/prover/plain.rs index e0092cf10..ac883ecbb 100644 --- a/sumcheck/src/prover/plain.rs +++ b/sumcheck/src/prover/plain.rs @@ -72,25 +72,25 @@ pub fn sumcheck_prove_plain Date: Wed, 25 Sep 2024 13:49:07 +0200 Subject: [PATCH 02/44] fix: parallel sum-check + periodic test failing --- air/src/air/logup_gkr/mod.rs | 20 +++++++------ prover/Cargo.toml | 2 +- prover/src/logup_gkr/mod.rs | 2 +- prover/src/logup_gkr/prover.rs | 2 +- sumcheck/src/prover/high_degree.rs | 16 +++++------ sumcheck/src/prover/plain.rs | 33 +++++++++++----------- sumcheck/src/verifier/mod.rs | 20 ++++++------- winterfell/src/tests/logup_gkr_periodic.rs | 2 +- 8 files changed, 50 insertions(+), 47 deletions(-) diff --git a/air/src/air/logup_gkr/mod.rs b/air/src/air/logup_gkr/mod.rs index e2719c626..628d442c2 100644 --- a/air/src/air/logup_gkr/mod.rs +++ b/air/src/air/logup_gkr/mod.rs @@ -156,7 +156,7 @@ pub trait LogUpGkrEvaluator: Clone + Sync { /// Returns the periodic values used in the LogUp-GKR statement, either as base field element /// during circuit evaluation or as extension field element during the run of sum-check for /// the input layer. - fn build_periodic_values(&self) -> PeriodicTable + fn build_periodic_values(&self, num_rows: usize) -> PeriodicTable where E: FieldElement, { @@ -166,7 +166,7 @@ pub trait LogUpGkrEvaluator: Clone + Sync { .map(|values| values.iter().map(|x| E::from(*x)).collect()) .collect(); - PeriodicTable { table } + PeriodicTable { table, num_rows } } } @@ -264,16 +264,17 @@ pub enum LogUpGkrOracle { #[derive(Clone, Debug, Default, PartialEq, PartialOrd, Eq, Ord)] pub struct PeriodicTable { pub table: Vec>, + pub num_rows: usize, } impl PeriodicTable where E: FieldElement, { - pub fn new(table: Vec>) -> Self { + pub fn new(table: Vec>, num_rows: usize) -> Self { let table = table.iter().map(|col| col.iter().map(|x| E::from(*x)).collect()).collect(); - Self { table } + Self { table, num_rows } } pub fn num_columns(&self) -> usize { @@ -294,12 +295,15 @@ where pub fn bind_least_significant_variable(&mut self, round_challenge: E) { for col in self.table.iter_mut() { if col.len() > 1 { - let num_evals = col.len() >> 1; - for i in 0..num_evals { - col[i] = col[i] + round_challenge * (col[i + num_evals] - col[i]); + if self.num_rows <= col.len() { + let num_evals = col.len() >> 1; + for i in 0..num_evals { + col[i] = col[i] + round_challenge * (col[i + num_evals] - col[i]); + } + col.truncate(num_evals) } - col.truncate(num_evals) } } + self.num_rows = self.num_rows / 2; } } diff --git a/prover/Cargo.toml b/prover/Cargo.toml index 37e45c472..49f5f71cd 100644 --- a/prover/Cargo.toml +++ b/prover/Cargo.toml @@ -29,7 +29,7 @@ harness = false [features] async = ["maybe_async/async"] -concurrent = ["crypto/concurrent", "math/concurrent", "fri/concurrent", "utils/concurrent", "std"] +concurrent = ["crypto/concurrent", "math/concurrent", "fri/concurrent", "utils/concurrent", "sumcheck/concurrent", "std"] default = ["std"] std = ["air/std", "crypto/std", "fri/std", "math/std", "utils/std"] diff --git a/prover/src/logup_gkr/mod.rs b/prover/src/logup_gkr/mod.rs index 016cd5218..7375c889a 100644 --- a/prover/src/logup_gkr/mod.rs +++ b/prover/src/logup_gkr/mod.rs @@ -116,7 +116,7 @@ impl EvaluatedCircuit { log_up_randomness: &[E], ) -> CircuitLayer { let num_fractions = evaluator.get_num_fractions(); - let periodic_values = evaluator.build_periodic_values(); + let periodic_values = evaluator.build_periodic_values(trace.main_segment().num_rows()); let mut input_layer_wires = unsafe { uninit_vector(trace.main_segment().num_rows() * num_fractions) }; diff --git a/prover/src/logup_gkr/prover.rs b/prover/src/logup_gkr/prover.rs index e39289217..90cd3d2ed 100644 --- a/prover/src/logup_gkr/prover.rs +++ b/prover/src/logup_gkr/prover.rs @@ -80,7 +80,7 @@ pub fn prove_gkr( let main_trace_mls = build_mls_from_main_trace_segment(evaluator.get_oracles(), main_trace.main_segment())?; // build the periodic table representing periodic columns as multi-linear extensions - let periodic_table = evaluator.build_periodic_values(); + let periodic_table = evaluator.build_periodic_values(main_trace.main_segment().num_rows()); // run the GKR prover for the input layer let final_layer_proof = prove_input_layer( diff --git a/sumcheck/src/prover/high_degree.rs b/sumcheck/src/prover/high_degree.rs index a9546baec..d00c01e59 100644 --- a/sumcheck/src/prover/high_degree.rs +++ b/sumcheck/src/prover/high_degree.rs @@ -241,7 +241,6 @@ pub fn sum_check_prove_higher_degree< // fold each multi-linear using the last random round challenge mls.iter_mut() .for_each(|ml| ml.bind_least_significant_variable(round_challenge)); - eq_nu.bind_least_significant_variable(round_challenge); let SumCheckRoundClaim { eval_point, claim: _claim } = reduce_claim(&round_proofs[num_rounds - 1], current_round_claim, round_challenge); @@ -420,11 +419,11 @@ fn sumcheck_round( vec![E::ZERO; num_periodic], vec![E::ZERO; num_periodic], vec![E::ZERO; num_periodic], - vec![E::ZERO; evaluator.max_degree()], vec![E::ZERO; evaluator.get_num_fractions()], vec![E::ZERO; evaluator.get_num_fractions()], vec![E::ZERO; num_mls], vec![E::ZERO; num_periodic], + vec![E::ZERO; evaluator.max_degree()], ) }, |( @@ -434,11 +433,11 @@ fn sumcheck_round( mut evals_periodic_zero, mut evals_periodic_one, mut evals_periodic_x, - mut poly_evals, mut numerators, mut denominators, mut deltas, mut deltas_periodic, + mut poly_evals, ), i| { for (j, ml) in mls.iter().enumerate() { @@ -451,7 +450,8 @@ fn sumcheck_round( // add evaluation of periodic columns periodic_table.fill_periodic_values_at(i, &mut evals_periodic_zero); - periodic_table.fill_periodic_values_at(i + (1 << num_rounds), &mut evals_periodic_one); + periodic_table + .fill_periodic_values_at(i + (1 << num_rounds), &mut evals_periodic_one); // compute the evaluation at 1 evaluator.evaluate_query( @@ -461,7 +461,7 @@ fn sumcheck_round( &mut numerators, &mut denominators, ); - poly_evals[0] = evaluate_composition_poly( + poly_evals[0] += evaluate_composition_poly( eq_mu, &numerators, &denominators, @@ -499,7 +499,7 @@ fn sumcheck_round( &mut numerators, &mut denominators, ); - *e = evaluate_composition_poly( + *e += evaluate_composition_poly( eq_mu, &numerators, &denominators, @@ -515,15 +515,15 @@ fn sumcheck_round( evals_periodic_zero, evals_periodic_one, evals_periodic_x, - poly_evals, numerators, denominators, deltas, deltas_periodic, + poly_evals, ) }, ) - .map(|(_, _, _, poly_evals, ..)| poly_evals) + .map(|(.., poly_evals)| poly_evals) .reduce( || vec![E::ZERO; evaluator.max_degree()], |mut acc, poly_eval| { diff --git a/sumcheck/src/prover/plain.rs b/sumcheck/src/prover/plain.rs index ac883ecbb..a9c38d345 100644 --- a/sumcheck/src/prover/plain.rs +++ b/sumcheck/src/prover/plain.rs @@ -129,25 +129,25 @@ pub fn sumcheck_prove_plain( periodic_columns: PeriodicTable, eval_point: &[E], ) -> Vec { - let mut evaluations = vec![]; + let mut evaluations = Vec::with_capacity(periodic_columns.num_columns()); for col in periodic_columns.table() { let ml = MultiLinearPoly::from_evaluations(col.to_vec()); let num_variables = ml.num_variables(); - let point = &eval_point[..num_variables]; - let evaluation = ml.evaluate(point); + let evaluation = ml.evaluate(&eval_point[&eval_point.len() - num_variables..]); evaluations.push(evaluation) } evaluations diff --git a/winterfell/src/tests/logup_gkr_periodic.rs b/winterfell/src/tests/logup_gkr_periodic.rs index 849cbbd5d..9b64b24d5 100644 --- a/winterfell/src/tests/logup_gkr_periodic.rs +++ b/winterfell/src/tests/logup_gkr_periodic.rs @@ -23,7 +23,7 @@ use crate::{ #[test] fn test_logup_gkr_periodic() { let aux_trace_width = 1; - let trace = LogUpGkrPeriodic::new(2_usize.pow(12), aux_trace_width); + let trace = LogUpGkrPeriodic::new(2_usize.pow(13), aux_trace_width); let prover = LogUpGkrPeriodicProver::new(aux_trace_width); let proof = prover.prove(trace).unwrap(); From 8eef5a79154ddea0503ecd2914fb503dc36cad0c Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Mon, 23 Sep 2024 13:33:53 +0200 Subject: [PATCH 03/44] wip --- air/Cargo.toml | 1 + air/src/air/logup_gkr/mod.rs | 14 ++- prover/src/logup_gkr/mod.rs | 144 +++++++++++++------------- prover/src/logup_gkr/prover.rs | 157 +++++++++++++++++++++-------- sumcheck/Cargo.toml | 2 +- sumcheck/src/lib.rs | 30 +++++- sumcheck/src/prover/high_degree.rs | 152 ++++++++++++++++++---------- sumcheck/src/prover/mod.rs | 3 +- sumcheck/src/prover/plain.rs | 147 +++++++++++++++++++++++++++ sumcheck/src/verifier/mod.rs | 143 ++++++++++++++++++++------ verifier/Cargo.toml | 1 + verifier/src/logup_gkr/mod.rs | 105 ++++++++++++++----- 12 files changed, 661 insertions(+), 238 deletions(-) diff --git a/air/Cargo.toml b/air/Cargo.toml index 4d1b0641f..e461ecf57 100644 --- a/air/Cargo.toml +++ b/air/Cargo.toml @@ -25,6 +25,7 @@ fri = { version = "0.9", path = "../fri", package = "winter-fri", default-featur libm = "0.2.8" math = { version = "0.9", path = "../math", package = "winter-math", default-features = false } utils = { version = "0.9", path = "../utils/core", package = "winter-utils", default-features = false } +libc-print = "0.1.23" [dev-dependencies] rand-utils = { version = "0.9", path = "../utils/rand", package = "winter-rand-utils" } diff --git a/air/src/air/logup_gkr/mod.rs b/air/src/air/logup_gkr/mod.rs index 628d442c2..27a9d6c78 100644 --- a/air/src/air/logup_gkr/mod.rs +++ b/air/src/air/logup_gkr/mod.rs @@ -4,6 +4,7 @@ // LICENSE file in the root directory of this source tree. use alloc::vec::Vec; +use libc_print::libc_println; use core::marker::PhantomData; use crypto::{ElementHasher, RandomCoin}; @@ -109,7 +110,7 @@ pub trait LogUpGkrEvaluator: Clone + Sync { /// [1]: https://eprint.iacr.org/2023/1284 fn generate_univariate_iop_for_multi_linear_opening_data( &self, - openings: Vec, + openings: Vec>, eval_point: Vec, public_coin: &mut impl RandomCoin, ) -> GkrData @@ -117,17 +118,26 @@ pub trait LogUpGkrEvaluator: Clone + Sync { E: FieldElement, H: ElementHasher, { + let openings: Vec = + openings[0].clone().chunks(2).flat_map(|ops| [ops[0], ops[1]]).collect(); public_coin.reseed(H::hash_elements(&openings)); + let folding_randomness: E = public_coin.draw().expect("failed to generate randomness"); + let batched_openings: Vec = + openings.chunks(2).map(|p| p[0] + folding_randomness * (p[1] - p[0])).collect(); + let mut batching_randomness = Vec::with_capacity(openings.len() - 1); for _ in 0..openings.len() - 1 { batching_randomness.push(public_coin.draw().expect("failed to generate randomness")) } + let mut eval_point = eval_point; + eval_point.push(folding_randomness); + libc_println!("folding ranomnes {:?}", folding_randomness); GkrData::new( LagrangeKernelRandElements::new(eval_point), batching_randomness, - openings, + batched_openings, self.get_oracles().to_vec(), ) } diff --git a/prover/src/logup_gkr/mod.rs b/prover/src/logup_gkr/mod.rs index 7375c889a..bcbbbd33c 100644 --- a/prover/src/logup_gkr/mod.rs +++ b/prover/src/logup_gkr/mod.rs @@ -56,7 +56,7 @@ pub use utils::rayon::{current_num_threads as rayon_num_threads, prelude::*}; /// This means that layer ν will be the output layer and will consist of four values /// (p_0[ν - 1], p_1[ν - 1], p_0[ν - 1], p_1[ν - 1]) ∈ 𝔽^ν. pub struct EvaluatedCircuit { - layer_polys: Vec>, + layer_polys: Vec>>, } impl EvaluatedCircuit { @@ -72,10 +72,10 @@ impl EvaluatedCircuit { let mut current_layer = Self::generate_input_layer(main_trace_columns, evaluator, log_up_randomness); - while current_layer.num_wires() > 1 { + while current_layer[0].num_wires() > 1 { let next_layer = Self::compute_next_layer(¤t_layer); - layer_polys.push(CircuitLayerPolys::from_circuit_layer(current_layer)); + layer_polys.push(CircuitLayerPolys::from_circuit_layer(¤t_layer)); current_layer = next_layer; } @@ -88,21 +88,25 @@ impl EvaluatedCircuit { /// Note that the return type is a slice of [`CircuitLayerPolys`] as opposed to /// [`CircuitLayer`], since the evaluated layers are stored in a representation which can be /// proved using GKR. - pub fn layers(self) -> Vec> { + pub fn layers(self) -> Vec>> { self.layer_polys } /// Returns the numerator/denominator polynomials representing the output layer of the circuit. - pub fn output_layer(&self) -> &CircuitLayerPolys { + pub fn output_layers(&self) -> &Vec> { self.layer_polys.last().expect("circuit has at least one layer") } /// Evaluates the output layer at `query`, where the numerators of the output layer are treated /// as evaluations of a multilinear polynomial, and similarly for the denominators. - pub fn evaluate_output_layer(&self, query: E) -> (E, E) { - let CircuitLayerPolys { numerators, denominators } = self.output_layer(); + pub fn evaluate_output_layer(&self, query: E) -> Vec<(E, E)> { + let mut res = vec![]; + for output_layer in self.output_layers().iter() { + let CircuitLayerPolys { numerators, denominators } = output_layer; - (numerators.evaluate(&[query]), denominators.evaluate(&[query])) + res.push((numerators.evaluate(&[query]), denominators.evaluate(&[query]))) + } + res } // HELPERS @@ -111,76 +115,66 @@ impl EvaluatedCircuit { /// Generates the input layer of the circuit from the main trace columns and some randomness /// provided by the verifier. fn generate_input_layer( - trace: &impl Trace, + main_trace: &impl Trace, evaluator: &impl LogUpGkrEvaluator, log_up_randomness: &[E], - ) -> CircuitLayer { + ) -> Vec> { let num_fractions = evaluator.get_num_fractions(); let periodic_values = evaluator.build_periodic_values(trace.main_segment().num_rows()); - let mut input_layer_wires = - unsafe { uninit_vector(trace.main_segment().num_rows() * num_fractions) }; - let num_cols = trace.main_segment().num_cols(); - let num_oracles = evaluator.get_oracles().len(); - let num_periodic_cols = periodic_values.num_columns(); - - batch_iter_mut!( - &mut input_layer_wires, - 1024, - |batch: &mut [CircuitWire], batch_offset: usize| { - let mut main_frame = EvaluationFrame::new(num_cols); - let mut query = vec![E::BaseField::ZERO; num_oracles]; - let mut periodic_values_row = vec![E::BaseField::ZERO; num_periodic_cols]; - let mut numerators = vec![E::ZERO; num_fractions]; - let mut denominators = vec![E::ZERO; num_fractions]; - - let row_offset = batch_offset / num_fractions; - let batch_size = batch.len(); - let num_rows_per_batch = batch_size / num_fractions; - - for i in - (0..trace.main_segment().num_rows()).skip(row_offset).take(num_rows_per_batch) - { - trace.read_main_frame(i, &mut main_frame); - periodic_values.fill_periodic_values_at(i, &mut periodic_values_row); - evaluator.build_query(&main_frame, &mut query); - - evaluator.evaluate_query( - &query, - &periodic_values_row, - log_up_randomness, - &mut numerators, - &mut denominators, - ); - - let n = (i - row_offset) * num_fractions; - for ((wire, numerator), denominator) in batch[n..n + num_fractions] - .iter_mut() - .zip(numerators.iter()) - .zip(denominators.iter()) - { - *wire = CircuitWire::new(*numerator, *denominator); - } - } - } - ); - - CircuitLayer::new(input_layer_wires) + let mut input_layer_wires: Vec> = + // Vec::with_capacity(main_trace.main_segment().num_rows() * num_fractions); + vec![vec![]; num_fractions]; + let mut main_frame = EvaluationFrame::new(main_trace.main_segment().num_cols()); + + let mut query = vec![E::BaseField::ZERO; evaluator.get_oracles().len()]; + let mut periodic_values_row = vec![E::BaseField::ZERO; periodic_values.num_columns()]; + let mut numerators = vec![E::ZERO; num_fractions]; + let mut denominators = vec![E::ZERO; num_fractions]; + for i in 0..main_trace.main_segment().num_rows() { + main_trace.read_main_frame(i, &mut main_frame); + periodic_values.fill_periodic_values_at(i, &mut periodic_values_row); + evaluator.build_query(&main_frame, &mut query); + + evaluator.evaluate_query( + &query, + &periodic_values_row, + log_up_randomness, + &mut numerators, + &mut denominators, + ); + numerators + .iter() + .zip(denominators.iter()) + .zip(input_layer_wires.iter_mut()) + .for_each(|((numerator, denominator), circuit_input_layer)| { + circuit_input_layer.push(CircuitWire::new(*numerator, *denominator)) + }); + } + + input_layer_wires + .iter() + .map(|input_layer| CircuitLayer::new(input_layer.to_vec())) + .collect() } /// Computes the subsequent layer of the circuit from a given layer. - fn compute_next_layer(prev_layer: &CircuitLayer) -> CircuitLayer { - let next_layer_wires = chunks!(prev_layer.wires(), 2) - .map(|input_wires| { - let left_input_wire = input_wires[0]; - let right_input_wire = input_wires[1]; - - // output wire - left_input_wire + right_input_wire - }) - .collect(); - - CircuitLayer::new(next_layer_wires) + fn compute_next_layer(prev_layers: &[CircuitLayer]) -> Vec> { + let mut next_layers = vec![]; + for prev_layer in prev_layers.iter() { + let next_layer_wires = chunks!(prev_layer.wires(), 2) + .map(|input_wires| { + let left_input_wire = input_wires[0]; + let right_input_wire = input_wires[1]; + + // output wire + left_input_wire + right_input_wire + }) + .collect(); + + next_layers.push(CircuitLayer::new(next_layer_wires)) + } + next_layers } } @@ -199,8 +193,12 @@ impl CircuitLayerPolys where E: FieldElement, { - pub fn from_circuit_layer(layer: CircuitLayer) -> Self { - Self::from_wires(layer.wires) + pub fn from_circuit_layer(layers: &[CircuitLayer]) -> Vec { + let mut result = vec![]; + for layer in layers { + result.push(Self::from_wires(layer.wires.clone())) + } + result } pub fn from_wires(wires: Vec>) -> Self { @@ -326,7 +324,7 @@ where #[derive(Debug)] pub struct GkrClaim { pub evaluation_point: Vec, - pub claimed_evaluation: (E, E), + pub claimed_evaluation: Vec<(E, E)>, } /// We receive our 4 multilinear polynomials which were evaluated at a random point: diff --git a/prover/src/logup_gkr/prover.rs b/prover/src/logup_gkr/prover.rs index 90cd3d2ed..afa53163e 100644 --- a/prover/src/logup_gkr/prover.rs +++ b/prover/src/logup_gkr/prover.rs @@ -4,12 +4,13 @@ use air::{LogUpGkrEvaluator, LogUpGkrOracle, PeriodicTable}; use crypto::{ElementHasher, RandomCoin}; use math::FieldElement; use sumcheck::{ - sum_check_prove_higher_degree, sumcheck_prove_plain, BeforeFinalLayerProof, CircuitOutput, - EqFunction, FinalLayerProof, GkrCircuitProof, MultiLinearPoly, SumCheckProof, + sum_check_prove_higher_degree, sumcheck_prove_plain_batched, + BeforeFinalLayerProof, CircuitOutput, EqFunction, FinalLayerProof, GkrCircuitProof, + MultiLinearPoly, SumCheckProof, }; use tracing::instrument; -use super::{reduce_layer_claim, CircuitLayerPolys, EvaluatedCircuit, GkrClaim, GkrProverError}; +use super::{CircuitLayerPolys, EvaluatedCircuit, GkrClaim, GkrProverError}; use crate::{matrix::ColMatrix, Trace}; // PROVER @@ -71,10 +72,11 @@ pub fn prove_gkr( let circuit = EvaluatedCircuit::new(main_trace, evaluator, &logup_randomness)?; // include the circuit output as part of the final proof - let CircuitLayerPolys { numerators, denominators } = circuit.output_layer().clone(); + let output_layers = circuit.output_layers().clone(); // run the GKR prover for all layers except the input layer - let (before_final_layer_proofs, gkr_claim) = prove_intermediate_layers(circuit, public_coin)?; + let (before_final_layer_proofs, gkr_claim, tensored_circuit_batching_randomness) = + prove_intermediate_layers(circuit, public_coin)?; // build the MLEs of the relevant main trace columns let main_trace_mls = @@ -89,11 +91,23 @@ pub fn prove_gkr( main_trace_mls, periodic_table, gkr_claim, + &tensored_circuit_batching_randomness, public_coin, )?; + let mut numerators_all_circuits = vec![]; + let mut denominators_all_circuits = vec![]; + for output_layer in output_layers { + let CircuitLayerPolys { numerators, denominators } = output_layer; + numerators_all_circuits.push(numerators); + denominators_all_circuits.push(denominators); + } + Ok(GkrCircuitProof { - circuit_outputs: CircuitOutput { numerators, denominators }, + circuit_outputs: CircuitOutput { + numerators: numerators_all_circuits, + denominators: denominators_all_circuits, + }, before_final_layer_proofs, final_layer_proof, }) @@ -111,23 +125,33 @@ fn prove_input_layer< multi_linear_ext_polys: Vec>, periodic_table: PeriodicTable, claim: GkrClaim, + tensored_batching_randomness: &[E], transcript: &mut C, ) -> Result, GkrProverError> { // parse the [GkrClaim] resulting from the previous GKR layer - let GkrClaim { evaluation_point, claimed_evaluation } = claim; - - transcript.reseed(H::hash_elements(&[claimed_evaluation.0, claimed_evaluation.1])); + let GkrClaim { + evaluation_point, + claimed_evaluation: claimed_evaluations, + } = claim; + for claimed_evaluation in claimed_evaluations.iter() { + transcript.reseed(H::hash_elements(&[claimed_evaluation.0, claimed_evaluation.1])); + } let r_batch = transcript.draw().map_err(|_| GkrProverError::FailedToGenerateChallenge)?; - let claim = claimed_evaluation.0 + claimed_evaluation.1 * r_batch; + let mut full_claim = E::ZERO; + for (circuit_idx, claimed_evaluation) in claimed_evaluations.iter().enumerate() { + let claim = claimed_evaluation.0 + claimed_evaluation.1 * r_batch; + full_claim += claim * tensored_batching_randomness[circuit_idx] + } let proof = sum_check_prove_higher_degree( evaluator, evaluation_point, - claim, + full_claim, r_batch, log_up_randomness, multi_linear_ext_polys, periodic_table, + tensored_batching_randomness, transcript, )?; @@ -172,18 +196,33 @@ fn prove_intermediate_layers< >( circuit: EvaluatedCircuit, transcript: &mut C, -) -> Result<(BeforeFinalLayerProof, GkrClaim), GkrProverError> { +) -> Result<(BeforeFinalLayerProof, GkrClaim, Vec), GkrProverError> { // absorb the circuit output layer. This corresponds to sending the four values of the output // layer to the verifier. The verifier then replies with a challenge `r` in order to evaluate // `p` and `q` at `r` as multi-linears. - let CircuitLayerPolys { numerators, denominators } = circuit.output_layer(); - let mut evaluations = numerators.evaluations().to_vec(); - evaluations.extend_from_slice(denominators.evaluations()); - transcript.reseed(H::hash_elements(&evaluations)); + let output_layers = circuit.output_layers(); + for output_layer in output_layers.into_iter() { + let mut evaluations = output_layer.numerators.evaluations().to_vec(); + evaluations.extend_from_slice(output_layer.denominators.evaluations()); + transcript.reseed(H::hash_elements(&evaluations)); + } // generate the challenge and reduce [p0, p1, q0, q1] to [pr, qr] let r = transcript.draw().map_err(|_| GkrProverError::FailedToGenerateChallenge)?; - let mut claimed_evaluation = circuit.evaluate_output_layer(r); + let mut claimed_evaluations = circuit.evaluate_output_layer(r); + let num_circuits = claimed_evaluations.len(); + let log_num_circuits = num_circuits.ilog2(); + assert_eq!(1 << log_num_circuits, num_circuits); + + let mut circuit_batching_randomness: Vec = vec![]; + for _ in 0..log_num_circuits { + let batching_r = + transcript.draw().map_err(|_| GkrProverError::FailedToGenerateChallenge)?; + circuit_batching_randomness.push(batching_r); + } + + let tensored_circuit_batching_randomness = + EqFunction::new(circuit_batching_randomness.into()).evaluations(); let mut layer_proofs: Vec> = Vec::new(); let mut evaluation_point = vec![r]; @@ -199,36 +238,31 @@ fn prove_intermediate_layers< // construct the Lagrange kernel evaluated at the previous GKR round randomness let mut eq_mle = EqFunction::ml_at(evaluation_point.into()); - let (numerators, denominators) = inner_layer.into_numerators_denominators(); - // run the sumcheck protocol let proof = sum_check_prove_num_rounds_degree_3( - claimed_evaluation, - numerators, - denominators, + inner_layer, + &claimed_evaluations, &mut eq_mle, + &tensored_circuit_batching_randomness, transcript, )?; // sample a random challenge to reduce claims - transcript.reseed(H::hash_elements(&proof.openings_claim.openings)); + for tmp in proof.openings_claim.openings.iter() { + transcript.reseed(H::hash_elements(tmp)); + } let r_layer = transcript.draw().map_err(|_| GkrProverError::FailedToGenerateChallenge)?; // reduce the claim - claimed_evaluation = { - let left_numerators_opening = proof.openings_claim.openings[0]; - let right_numerators_opening = proof.openings_claim.openings[1]; - let left_denominators_opening = proof.openings_claim.openings[2]; - let right_denominators_opening = proof.openings_claim.openings[3]; - - reduce_layer_claim( - left_numerators_opening, - right_numerators_opening, - left_denominators_opening, - right_denominators_opening, - r_layer, - ) - }; + for (j, ops) in proof.openings_claim.openings.iter().enumerate() { + let p0 = ops[0]; + let p1 = ops[1]; + let q0 = ops[2]; + let q1 = ops[3]; + + let reduced_claim = (p0 + r_layer * (p1 - p0), q0 + r_layer * (q1 - q0)); + claimed_evaluations[j] = reduced_claim; + } // collect the randomness used for the current layer let mut ext = proof.openings_claim.eval_point.clone(); @@ -240,7 +274,11 @@ fn prove_intermediate_layers< Ok(( BeforeFinalLayerProof { proof: layer_proofs }, - GkrClaim { evaluation_point, claimed_evaluation }, + GkrClaim { + evaluation_point, + claimed_evaluation: claimed_evaluations, + }, + tensored_circuit_batching_randomness, )) } @@ -251,18 +289,49 @@ fn sum_check_prove_num_rounds_degree_3< C: RandomCoin, H: ElementHasher, >( - claim: (E, E), - p: MultiLinearPoly, - q: MultiLinearPoly, + inner_layers: Vec>, + claims: &[(E, E)], eq: &mut MultiLinearPoly, + tensored_batching_randomness: &[E], transcript: &mut C, ) -> Result, GkrProverError> { // generate challenge to batch two sumchecks - transcript.reseed(H::hash_elements(&[claim.0, claim.1])); + for claim in claims { + transcript.reseed(H::hash_elements(&[claim.0, claim.1])); + } let r_batch = transcript.draw().map_err(|_| GkrProverError::FailedToGenerateChallenge)?; - let claim = claim.0 + claim.1 * r_batch; + let mut batched_claims = vec![]; + for claim in claims { + let claim = claim.0 + claim.1 * r_batch; + batched_claims.push(claim) + } - let proof = sumcheck_prove_plain(claim, r_batch, p, q, eq, transcript)?; + let mut numerators_accross_circuits_0 = vec![]; + let mut denominators_accross_circuits_0 = vec![]; + let mut numerators_accross_circuits_1 = vec![]; + let mut denominators_accross_circuits_1 = vec![]; + + for tu in inner_layers { + let CircuitLayerPolys { numerators, denominators } = tu; + let (p0, p1) = numerators.project_least_significant_variable(); + let (q0, q1) = denominators.project_least_significant_variable(); + numerators_accross_circuits_0.push(p0); + numerators_accross_circuits_1.push(p1); + denominators_accross_circuits_0.push(q0); + denominators_accross_circuits_1.push(q1) + } + + let proof = sumcheck_prove_plain_batched( + &batched_claims, + r_batch, + numerators_accross_circuits_0, + numerators_accross_circuits_1, + denominators_accross_circuits_0, + denominators_accross_circuits_1, + eq, + tensored_batching_randomness, + transcript, + )?; Ok(proof) } diff --git a/sumcheck/Cargo.toml b/sumcheck/Cargo.toml index 7db2e8058..4563890bc 100644 --- a/sumcheck/Cargo.toml +++ b/sumcheck/Cargo.toml @@ -41,7 +41,7 @@ utils = { version = "0.9", path = "../utils/core", package = "winter-utils", def rayon = { version = "1.8", optional = true } smallvec = { version = "1.13", default-features = false } thiserror = { version = "1.0", git = "https://github.com/bitwalker/thiserror", branch = "no-std", default-features = false } - +libc-print = "0.1.23" [dev-dependencies] criterion = "0.5" rand-utils = { version = "0.9", path = "../utils/rand", package = "winter-rand-utils" } \ No newline at end of file diff --git a/sumcheck/src/lib.rs b/sumcheck/src/lib.rs index b7f670a9d..4672fdb7d 100644 --- a/sumcheck/src/lib.rs +++ b/sumcheck/src/lib.rs @@ -37,7 +37,7 @@ pub use multilinear::{inner_product, EqFunction, MultiLinearPoly}; #[derive(Clone, Debug)] pub struct FinalOpeningClaim { pub eval_point: Vec, - pub openings: Vec, + pub openings: Vec>, } impl Serializable for FinalOpeningClaim { @@ -229,8 +229,8 @@ where /// Holds the output layer of an [`EvaluatedCircuit`]. #[derive(Clone, Debug)] pub struct CircuitOutput { - pub numerators: MultiLinearPoly, - pub denominators: MultiLinearPoly, + pub numerators: Vec>, + pub denominators: Vec>, } impl Serializable for CircuitOutput @@ -250,8 +250,8 @@ where { fn read_from(source: &mut R) -> Result { Ok(Self { - numerators: MultiLinearPoly::read_from(source)?, - denominators: MultiLinearPoly::read_from(source)?, + numerators: Vec::>::read_from(source)?, + denominators: Vec::>::read_from(source)?, }) } } @@ -279,3 +279,23 @@ pub fn evaluate_composition_poly( .map(|(p, (q, eq_w))| *eq_w * comb_func(p[0], p[1], q[0], q[1], eq_eval, r_sum_check)) .fold(E::ZERO, |acc, x| acc + x) } +/// The non-linear composition polynomial of the LogUp-GKR protocol specific to the input layer. +pub fn evaluate_composition_poly_2( + eq_at_mu: &[E], + numerators_zero: &[E], + denominators_zero: &[E], + numerators_one: &[E], + denominators_one: &[E], + eq_eval: E, + r_sum_check: E, +) -> E { + numerators_zero + .iter() + .zip( + numerators_one + .iter() + .zip(denominators_zero.iter().zip(denominators_one.iter().zip(eq_at_mu.iter()))) + ) + .map(|(p0, (p1, (q0, (q1, eq_w))))| {*eq_w * comb_func(*p0, *p1, *q0, *q1, eq_eval, r_sum_check)}) + .fold(E::ZERO, |acc, x| acc + x) +} \ No newline at end of file diff --git a/sumcheck/src/prover/high_degree.rs b/sumcheck/src/prover/high_degree.rs index d00c01e59..e8652a15c 100644 --- a/sumcheck/src/prover/high_degree.rs +++ b/sumcheck/src/prover/high_degree.rs @@ -7,14 +7,14 @@ use alloc::vec::Vec; use air::{LogUpGkrEvaluator, PeriodicTable}; use crypto::{ElementHasher, RandomCoin}; +use libc_print::libc_println; use math::FieldElement; #[cfg(feature = "concurrent")] pub use rayon::prelude::*; use super::SumCheckProverError; use crate::{ - evaluate_composition_poly, CompressedUnivariatePolyEvals, EqFunction, FinalOpeningClaim, - MultiLinearPoly, RoundProof, SumCheckProof, SumCheckRoundClaim, + evaluate_composition_poly, evaluate_composition_poly_2, CompressedUnivariatePolyEvals, EqFunction, FinalOpeningClaim, MultiLinearPoly, RoundProof, SumCheckProof, SumCheckRoundClaim }; /// A sum-check prover for the input layer which can accommodate non-linear expressions in @@ -161,28 +161,23 @@ pub fn sum_check_prove_higher_degree< log_up_randomness: Vec, mut mls: Vec>, mut periodic_table: PeriodicTable, + tensored_circuits_batching: &[E], coin: &mut impl RandomCoin, ) -> Result, SumCheckProverError> { - let num_rounds = mls[0].num_variables(); + let num_rounds = mls[0].num_variables() - 1; let mut round_proofs = vec![]; - // split the evaluation point into two points of dimension mu and nu, respectively - let mu = evaluator.get_num_fractions().trailing_zeros() - 1; - let (evaluation_point_nu, evaluation_point_mu) = - evaluation_point.split_at(evaluation_point.len() - mu as usize); - - let eq_mu = EqFunction::ml_at(evaluation_point_mu.into()).evaluations().to_vec(); - let mut eq_nu = EqFunction::ml_at(evaluation_point_nu.into()); + let mut eq_mle = EqFunction::ml_at(evaluation_point.into()); // setup first round claim let mut current_round_claim = SumCheckRoundClaim { eval_point: vec![], claim }; // run the first round of the protocol let round_poly_evals = sumcheck_round( - &eq_mu, + &tensored_circuits_batching, evaluator, - &eq_nu, + &eq_mle, &mls, &periodic_table, &log_up_randomness, @@ -206,7 +201,7 @@ pub fn sum_check_prove_higher_degree< // fold each multi-linear using the round challenge mls.iter_mut() .for_each(|ml| ml.bind_least_significant_variable(round_challenge)); - eq_nu.bind_least_significant_variable(round_challenge); + eq_mle.bind_least_significant_variable(round_challenge); // fold each periodic multi-linear using the round challenge periodic_table.bind_least_significant_variable(round_challenge); @@ -214,9 +209,9 @@ pub fn sum_check_prove_higher_degree< // run the i-th round of the protocol using the folded multi-linears for the new reduced // claim. This basically computes the s_i polynomial. let round_poly_evals = sumcheck_round( - &eq_mu, + &tensored_circuits_batching, evaluator, - &eq_nu, + &eq_mle, &mls, &periodic_table, &log_up_randomness, @@ -245,10 +240,11 @@ pub fn sum_check_prove_higher_degree< let SumCheckRoundClaim { eval_point, claim: _claim } = reduce_claim(&round_proofs[num_rounds - 1], current_round_claim, round_challenge); - let openings = mls.iter_mut().map(|ml| ml.evaluations()[0]).collect(); + let openings :Vec = mls.into_iter().flat_map(|ml| [ml.evaluations()[0], ml.evaluations()[1]]).collect(); + libc_println!("openings prover {:?}", openings); Ok(SumCheckProof { - openings_claim: FinalOpeningClaim { eval_point, openings }, + openings_claim: FinalOpeningClaim { eval_point, openings: vec![openings] }, round_proofs, }) } @@ -306,75 +302,114 @@ fn sumcheck_round( let num_mls = mls.len(); let num_periodic = periodic_table.num_columns(); let num_vars = mls[0].num_variables(); - let num_rounds = num_vars - 1; + let num_rounds = num_vars - 1 - 1; #[cfg(not(feature = "concurrent"))] let evaluations = { - let mut evals_one = vec![E::ZERO; num_mls]; - let mut evals_zero = vec![E::ZERO; num_mls]; - let mut evals_x = vec![E::ZERO; num_mls]; + let mut evals_one_zero = vec![E::ZERO; num_mls]; + let mut evals_one_one = vec![E::ZERO; num_mls]; + let mut evals_zero_zero = vec![E::ZERO; num_mls]; + let mut evals_zero_one = vec![E::ZERO; num_mls]; + + let mut evals_x_zero = vec![E::ZERO; num_mls]; + let mut evals_x_one = vec![E::ZERO; num_mls]; + + let mut evals_periodic_zero_zero = vec![E::ZERO; num_periodic]; + let mut evals_periodic_zero_one = vec![E::ZERO; num_periodic]; + let mut evals_periodic_one_zero = vec![E::ZERO; num_periodic]; + let mut evals_periodic_one_one = vec![E::ZERO; num_periodic]; + + let mut evals_periodic_x_zero = vec![E::ZERO; num_periodic]; + let mut evals_periodic_x_one = vec![E::ZERO; num_periodic]; - let mut evals_periodic_one = vec![E::ZERO; num_periodic]; - let mut evals_periodic_zero = vec![E::ZERO; num_periodic]; - let mut evals_periodic_x = vec![E::ZERO; num_periodic]; let mut eq_x = E::ZERO; - let mut deltas = vec![E::ZERO; num_mls]; - let mut deltas_periodic = vec![E::ZERO; num_periodic]; + let mut deltas_zero = vec![E::ZERO; num_mls]; + let mut deltas_one = vec![E::ZERO; num_mls]; + let mut deltas_periodic_zero = vec![E::ZERO; num_periodic]; + let mut deltas_periodic_one = vec![E::ZERO; num_periodic]; let mut eq_delta = E::ZERO; - let mut numerators = vec![E::ZERO; evaluator.get_num_fractions()]; - let mut denominators = vec![E::ZERO; evaluator.get_num_fractions()]; + let mut numerators_zero = vec![E::ZERO; evaluator.get_num_fractions()]; + let mut denominators_zero = vec![E::ZERO; evaluator.get_num_fractions()]; + let mut numerators_one = vec![E::ZERO; evaluator.get_num_fractions()]; + let mut denominators_one = vec![E::ZERO; evaluator.get_num_fractions()]; (0..1 << num_rounds) .map(|i| { let mut total_evals = vec![E::ZERO; evaluator.max_degree()]; for (j, ml) in mls.iter().enumerate() { - evals_zero[j] = ml.evaluations()[i]; - evals_one[j] = ml.evaluations()[i + (1 << num_rounds)]; + evals_zero_zero[j] = ml.evaluations()[2 * i]; + evals_zero_one[j] = ml.evaluations()[2 * i + 1]; + evals_one_zero[j] = ml.evaluations()[2 * (i + num_rounds)]; + evals_one_one[j] = ml.evaluations()[2 * (i + num_rounds) + 1]; } let eq_at_zero = eq_ml.evaluations()[i]; - let eq_at_one = eq_ml.evaluations()[i + (1 << num_rounds)]; + let eq_at_one = eq_ml.evaluations()[i + ( num_rounds)]; // add evaluation of periodic columns - periodic_table.fill_periodic_values_at(i, &mut evals_periodic_zero); - periodic_table - .fill_periodic_values_at(i + (1 << num_rounds), &mut evals_periodic_one); + periodic_table.fill_periodic_values_at(2 * i, &mut evals_periodic_zero_zero); + periodic_table.fill_periodic_values_at(2 * i + 1, &mut evals_periodic_zero_one); + periodic_table.fill_periodic_values_at(2 * (i + num_rounds), &mut evals_periodic_one_zero); + periodic_table.fill_periodic_values_at(2 * (i + num_rounds) + 1, &mut evals_periodic_one_one); + // compute the evaluation at 1 evaluator.evaluate_query( - &evals_one, - &evals_periodic_one, + &evals_one_zero, + &evals_periodic_one_zero, log_up_randomness, - &mut numerators, - &mut denominators, + &mut numerators_zero, + &mut denominators_zero, ); - total_evals[0] = evaluate_composition_poly( + evaluator.evaluate_query( + &evals_one_one, + &evals_periodic_one_one, + log_up_randomness, + &mut numerators_one, + &mut denominators_one, + ); + total_evals[0] = evaluate_composition_poly_2( eq_mu, - &numerators, - &denominators, + &numerators_zero, + &numerators_one, + &denominators_zero, + &denominators_one, eq_at_one, r_sum_check, ); // compute the evaluations at 2, ..., d_max points for i in 0..num_mls { - deltas[i] = evals_one[i] - evals_zero[i]; - evals_x[i] = evals_one[i]; + deltas_zero[i] = evals_one_zero[i] - evals_zero_zero[i]; + evals_x_zero[i] = evals_one_zero[i]; + deltas_one[i] = evals_one_one[i] - evals_zero_one[i]; + evals_x_one[i] = evals_one_one[i]; } for i in 0..num_periodic { - deltas_periodic[i] = evals_periodic_one[i] - evals_periodic_zero[i]; - evals_periodic_x[i] = evals_periodic_one[i]; + deltas_periodic_zero[i] = + evals_periodic_zero_one[i] - evals_periodic_zero_zero[i]; + evals_periodic_x_zero[i] = evals_periodic_zero_one[i]; + deltas_periodic_one[i] = evals_periodic_one_one[i] - evals_periodic_one_zero[i]; + evals_periodic_x_one[i] = evals_periodic_one_one[i]; } eq_delta = eq_at_one - eq_at_zero; eq_x = eq_at_one; for e in total_evals.iter_mut().skip(1) { - evals_x.iter_mut().zip(deltas.iter()).for_each(|(evx, delta)| { + evals_x_zero.iter_mut().zip(deltas_zero.iter()).for_each(|(evx, delta)| { *evx += *delta; }); - evals_periodic_x.iter_mut().zip(deltas_periodic.iter()).for_each( + evals_periodic_x_zero.iter_mut().zip(deltas_periodic_zero.iter()).for_each( + |(evx, delta)| { + *evx += *delta; + }, + ); + evals_x_one.iter_mut().zip(deltas_one.iter()).for_each(|(evx, delta)| { + *evx += *delta; + }); + evals_periodic_x_one.iter_mut().zip(deltas_periodic_one.iter()).for_each( |(evx, delta)| { *evx += *delta; }, @@ -382,16 +417,25 @@ fn sumcheck_round( eq_x += eq_delta; evaluator.evaluate_query( - &evals_x, - &evals_periodic_x, + &evals_x_zero, + &evals_periodic_x_zero, log_up_randomness, - &mut numerators, - &mut denominators, + &mut numerators_zero, + &mut denominators_zero, ); - *e = evaluate_composition_poly( + evaluator.evaluate_query( + &evals_x_one, + &evals_periodic_x_one, + log_up_randomness, + &mut numerators_one, + &mut denominators_one, + ); + *e = evaluate_composition_poly_2( eq_mu, - &numerators, - &denominators, + &numerators_zero, + &numerators_one, + &denominators_zero, + &denominators_one, eq_x, r_sum_check, ); diff --git a/sumcheck/src/prover/mod.rs b/sumcheck/src/prover/mod.rs index 13d35e551..0a25b727a 100644 --- a/sumcheck/src/prover/mod.rs +++ b/sumcheck/src/prover/mod.rs @@ -7,7 +7,8 @@ mod high_degree; pub use high_degree::sum_check_prove_higher_degree; mod plain; -pub use plain::sumcheck_prove_plain; +//pub use plain::sumcheck_prove_plain; +pub use plain::sumcheck_prove_plain_batched; mod error; pub use error::SumCheckProverError; diff --git a/sumcheck/src/prover/plain.rs b/sumcheck/src/prover/plain.rs index a9c38d345..db9864899 100644 --- a/sumcheck/src/prover/plain.rs +++ b/sumcheck/src/prover/plain.rs @@ -3,6 +3,7 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. +use alloc::vec::Vec; use crypto::{ElementHasher, RandomCoin}; use math::FieldElement; #[cfg(feature = "concurrent")] @@ -15,6 +16,8 @@ use crate::{ SumCheckProof, }; +/* + /// Sum-check prover for non-linear multivariate polynomial of the simple LogUp-GKR. /// /// More specifically, the following function implements the logic of the sum-check prover as @@ -213,3 +216,147 @@ pub fn sumcheck_prove_plain>( + claims: &[E], + r_batch: E, + mut p0_s: Vec>, + mut p1_s: Vec>, + mut q0_s: Vec>, + mut q1_s: Vec>, + eq: &mut MultiLinearPoly, + tensored_batching_randomness: &[E], + transcript: &mut impl RandomCoin, +) -> Result, SumCheckProverError> { + let mut round_proofs = vec![]; + + let mut challenges = vec![]; + + let mut all_claim = E::ZERO; + for (circuit_id, claim) in claims.iter().enumerate() { + all_claim += *claim * tensored_batching_randomness[circuit_id]; + } + let num_rounds = p0_s[0].num_variables(); + for l in 0..num_rounds { + let mut all_round_poly_eval_at_1 = E::ZERO; + let mut all_round_poly_eval_at_2 = E::ZERO; + let mut all_round_poly_eval_at_3 = E::ZERO; + let len = p0_s[0].num_evaluations() / 2; + + for (circuit_id, (p0, (p1, (q0, q1)))) in p0_s + .iter_mut() + .zip(p1_s.iter_mut().zip(q0_s.iter_mut().zip(q1_s.iter_mut()))) + .enumerate() + { + let (round_poly_eval_at_1, round_poly_eval_at_2, round_poly_eval_at_3) = (0..len).fold( + (E::ZERO, E::ZERO, E::ZERO), + |(acc_point_1, acc_point_2, acc_point_3), i| { + let cur_len = len; + let round_poly_eval_at_1 = comb_func( + p0[i + cur_len], + p1[i + cur_len], + q0[i + cur_len], + q1[i + cur_len], + eq[i + cur_len], + r_batch, + ); + + let p0_delta = p0[i + cur_len] - p0[i]; + let p1_delta = p1[i + cur_len] - p1[i]; + let q0_delta = q0[i + cur_len] - q0[i]; + let q1_delta = q1[i + cur_len] - q1[i]; + let eq_delta = eq[i + cur_len] - eq[i]; + + let mut p0_eval_at_x = p0[i + cur_len] + p0_delta; + let mut p1_eval_at_x = p1[i + cur_len] + p1_delta; + let mut q0_eval_at_x = q0[i + cur_len] + q0_delta; + let mut q1_eval_at_x = q1[i + cur_len] + q1_delta; + let mut eq_evx = eq[i + cur_len] + eq_delta; + let round_poly_eval_at_2 = comb_func( + p0_eval_at_x, + p1_eval_at_x, + q0_eval_at_x, + q1_eval_at_x, + eq_evx, + r_batch, + ); + + p0_eval_at_x += p0_delta; + p1_eval_at_x += p1_delta; + q0_eval_at_x += q0_delta; + q1_eval_at_x += q1_delta; + eq_evx += eq_delta; + let round_poly_eval_at_3 = comb_func( + p0_eval_at_x, + p1_eval_at_x, + q0_eval_at_x, + q1_eval_at_x, + eq_evx, + r_batch, + ); + + ( + round_poly_eval_at_1 + acc_point_1, + round_poly_eval_at_2 + acc_point_2, + round_poly_eval_at_3 + acc_point_3, + ) + }, + ); + + all_round_poly_eval_at_1 += + round_poly_eval_at_1 * tensored_batching_randomness[circuit_id]; + all_round_poly_eval_at_2 += + round_poly_eval_at_2 * tensored_batching_randomness[circuit_id]; + all_round_poly_eval_at_3 += + round_poly_eval_at_3 * tensored_batching_randomness[circuit_id]; + } + + let evals = + smallvec![all_round_poly_eval_at_1, all_round_poly_eval_at_2, all_round_poly_eval_at_3]; + let compressed_round_poly_evals = CompressedUnivariatePolyEvals(evals); + let compressed_round_poly = compressed_round_poly_evals.to_poly(all_claim); + + // reseed with the s_i polynomial + transcript.reseed(H::hash_elements(&compressed_round_poly.0)); + let round_proof = RoundProof { + round_poly_coefs: compressed_round_poly.clone(), + }; + + let round_challenge = + transcript.draw().map_err(|_| SumCheckProverError::FailedToGenerateChallenge)?; + + for (p0, (p1, (q0, q1))) in p0_s.iter_mut().zip(p1_s.iter_mut().zip(q0_s.iter_mut().zip(q1_s.iter_mut()))) { + // fold each multi-linear using the round challenge + p0.bind_least_significant_variable(round_challenge); + p1.bind_least_significant_variable(round_challenge); + q0.bind_least_significant_variable(round_challenge); + q1.bind_least_significant_variable(round_challenge); + + }eq.bind_least_significant_variable(round_challenge); + + // compute the new reduced round claim + all_claim = compressed_round_poly.evaluate_using_claim(&all_claim, &round_challenge); + + round_proofs.push(round_proof); + challenges.push(round_challenge); + } + + let mut openings = vec![]; + for (p0, (p1, (q0, q1))) in p0_s.iter_mut().zip(p1_s.iter_mut().zip(q0_s.iter_mut().zip(q1_s.iter_mut()))) { + openings.push(vec![p0[0], p1[0], q0[0], q1[0]]) + } + + + Ok(SumCheckProof { + openings_claim: FinalOpeningClaim { + eval_point: challenges, + openings, + }, + round_proofs, + }) +} diff --git a/sumcheck/src/verifier/mod.rs b/sumcheck/src/verifier/mod.rs index 98ed6d0e4..bddc9ceba 100644 --- a/sumcheck/src/verifier/mod.rs +++ b/sumcheck/src/verifier/mod.rs @@ -7,11 +7,11 @@ use alloc::vec::Vec; use air::{LogUpGkrEvaluator, PeriodicTable}; use crypto::{ElementHasher, RandomCoin}; +use libc_print::libc_println; use math::FieldElement; use crate::{ - comb_func, evaluate_composition_poly, EqFunction, FinalLayerProof, FinalOpeningClaim, - MultiLinearPoly, RoundProof, SumCheckProof, SumCheckRoundClaim, + comb_func, evaluate_composition_poly, evaluate_composition_poly_2, EqFunction, FinalLayerProof, FinalOpeningClaim, MultiLinearPoly, RoundProof, SumCheckProof, SumCheckRoundClaim }; /// Verifies sum-check proofs, as part of the GKR proof, for all GKR layers except for the last one @@ -22,31 +22,47 @@ pub fn verify_sum_check_intermediate_layers< >( proof: &SumCheckProof, gkr_eval_point: &[E], - claim: (E, E), + claims: &[(E, E)], + tensored_circuit_batching_randomness: &[E], transcript: &mut impl RandomCoin, ) -> Result, SumCheckVerifierError> { // generate challenge to batch sum-checks - transcript.reseed(H::hash_elements(&[claim.0, claim.1])); + for claim in claims { + transcript.reseed(H::hash_elements(&[claim.0, claim.1])); + } let r_batch: E = transcript .draw() .map_err(|_| SumCheckVerifierError::FailedToGenerateChallenge)?; - // compute the claim for the batched sum-check - let reduced_claim = claim.0 + claim.1 * r_batch; + let mut batched_claims = vec![]; + for claim in claims { + let claim = claim.0 + claim.1 * r_batch; + batched_claims.push(claim) + } + let mut full_claim = E::ZERO; + for (circuit_id, claim) in batched_claims.iter().enumerate() { + full_claim += *claim * tensored_circuit_batching_randomness[circuit_id] + } let SumCheckProof { openings_claim, round_proofs } = proof; - let final_round_claim = verify_rounds(reduced_claim, round_proofs, transcript)?; + let final_round_claim = verify_rounds(full_claim, round_proofs, transcript)?; assert_eq!(openings_claim.eval_point, final_round_claim.eval_point); - let p0 = openings_claim.openings[0]; - let p1 = openings_claim.openings[1]; - let q0 = openings_claim.openings[2]; - let q1 = openings_claim.openings[3]; - - let eq = EqFunction::new(gkr_eval_point.into()).evaluate(&openings_claim.eval_point); + let mut eval_batched_circuits = E::ZERO; + let eq = EqFunction::new(gkr_eval_point.into()).evaluate(&openings_claim.eval_point.clone()); + for (circuit_idx, openings) in openings_claim.openings.iter().enumerate() { + let p0 = openings[0].clone(); + let p1 = openings[1].clone(); + let q0 = openings[2].clone(); + let q1 = openings[3].clone(); - if comb_func(p0, p1, q0, q1, eq, r_batch) != final_round_claim.claim { + eval_batched_circuits += comb_func(p0, p1, q0, q1, eq, r_batch) + * tensored_circuit_batching_randomness[circuit_idx] + } + libc_println!("we are here !!!!"); + if eval_batched_circuits != final_round_claim.claim { + //libc_println!("we are here !!!!"); return Err(SumCheckVerifierError::FinalEvaluationCheckFailed); } @@ -63,29 +79,46 @@ pub fn verify_sum_check_input_layer, log_up_randomness: Vec, gkr_eval_point: &[E], - claim: (E, E), + claim: Vec<(E, E)>, + tensored_circuit_batching_randomness: &[E], transcript: &mut impl RandomCoin, ) -> Result, SumCheckVerifierError> { + for claimed_evaluation in claim.iter() { + transcript.reseed(H::hash_elements(&[claimed_evaluation.0, claimed_evaluation.1])); + } // generate challenge to batch sum-checks - transcript.reseed(H::hash_elements(&[claim.0, claim.1])); + let r_batch: E = transcript .draw() .map_err(|_| SumCheckVerifierError::FailedToGenerateChallenge)?; - // compute the claim for the batched sum-check - let reduced_claim = claim.0 + claim.1 * r_batch; + let mut batched_claims = vec![]; + for claimed_evaluation in claim.iter() { + let claim = claimed_evaluation.0 + claimed_evaluation.1 * r_batch; + batched_claims.push(claim) + } + + let mut full_claim = E::ZERO; + for (circuit_id, claim) in batched_claims.iter().enumerate() { + full_claim += *claim * tensored_circuit_batching_randomness[circuit_id] + } // verify the sum-check proof let SumCheckRoundClaim { eval_point, claim } = - verify_rounds(reduced_claim, &proof.0.round_proofs, transcript)?; + verify_rounds(full_claim, &proof.0.round_proofs, transcript)?; // execute the final evaluation check if proof.0.openings_claim.eval_point != eval_point { + libc_println!("proof.0.openings_claim.eval_point {:?}", proof.0.openings_claim.eval_point); + libc_println!("eval_point {:?}", eval_point); return Err(SumCheckVerifierError::WrongOpeningPoint); } + libc_println!("HERE!"); - let mut numerators = vec![E::ZERO; evaluator.get_num_fractions()]; - let mut denominators = vec![E::ZERO; evaluator.get_num_fractions()]; + let mut numerators_zero = vec![E::ZERO; evaluator.get_num_fractions()]; + let mut numerators_one = vec![E::ZERO; evaluator.get_num_fractions()]; + let mut denominators_zero = vec![E::ZERO; evaluator.get_num_fractions()]; + let mut denominators_one = vec![E::ZERO; evaluator.get_num_fractions()]; let mu = evaluator.get_num_fractions().trailing_zeros() - 1; let (evaluation_point_nu, evaluation_point_mu) = @@ -98,19 +131,45 @@ pub fn verify_sum_check_input_layer( +fn evaluate_periodic_columns_at_old( periodic_columns: PeriodicTable, eval_point: &[E], ) -> Vec { @@ -173,3 +232,27 @@ fn evaluate_periodic_columns_at( } evaluations } +fn evaluate_periodic_columns_at( + periodic_columns: PeriodicTable, + eval_point: &[E], +) -> (Vec, Vec) { + let mut eval_point_zero = vec![E::ZERO]; + let mut eval_point_one = vec![E::ZERO]; + eval_point_zero.extend_from_slice(&eval_point); + eval_point_one.extend_from_slice(&eval_point); + + let mut evaluations_zero = vec![]; + let mut evaluations_one = vec![]; + for col in periodic_columns.table() { + let ml = MultiLinearPoly::from_evaluations(col.to_vec()); + let num_variables = ml.num_variables(); + let point_zero = &eval_point_zero[..num_variables]; + let point_one = &eval_point_one[..num_variables]; + + let evaluation_zero = ml.evaluate(point_zero); + evaluations_zero.push(evaluation_zero); + let evaluation_one= ml.evaluate(point_one); + evaluations_one.push(evaluation_one) + } + (evaluations_zero, evaluations_one) +} \ No newline at end of file diff --git a/verifier/Cargo.toml b/verifier/Cargo.toml index 0c07d493c..b60ab0347 100644 --- a/verifier/Cargo.toml +++ b/verifier/Cargo.toml @@ -27,6 +27,7 @@ math = { version = "0.9", path = "../math", package = "winter-math", default-fea sumcheck = { version = "0.1", path = "../sumcheck", package = "winter-sumcheck", default-features = false } thiserror = { version = "1.0", git = "https://github.com/bitwalker/thiserror", branch = "no-std", default-features = false } utils = { version = "0.9", path = "../utils/core", package = "winter-utils", default-features = false } +libc-print = "0.1.23" # Allow math in docs [package.metadata.docs.rs] diff --git a/verifier/src/logup_gkr/mod.rs b/verifier/src/logup_gkr/mod.rs index 8a65e3543..fa35e9c54 100644 --- a/verifier/src/logup_gkr/mod.rs +++ b/verifier/src/logup_gkr/mod.rs @@ -2,10 +2,10 @@ use alloc::vec::Vec; use air::{Air, LogUpGkrEvaluator}; use crypto::{ElementHasher, RandomCoin}; +use libc_print::libc_println; use math::FieldElement; use sumcheck::{ - verify_sum_check_input_layer, verify_sum_check_intermediate_layers, CircuitOutput, - FinalOpeningClaim, GkrCircuitProof, SumCheckVerifierError, + verify_sum_check_input_layer, verify_sum_check_intermediate_layers, CircuitOutput, EqFunction, FinalOpeningClaim, GkrCircuitProof, SumCheckVerifierError }; /// Verifies the validity of a GKR proof for a LogUp-GKR relation. @@ -34,32 +34,73 @@ pub fn verify_gkr< } = proof; let CircuitOutput { numerators, denominators } = circuit_outputs; - let p0 = numerators.evaluations()[0]; - let p1 = numerators.evaluations()[1]; - let q0 = denominators.evaluations()[0]; - let q1 = denominators.evaluations()[1]; - - // make sure that both denominators are not equal to E::ZERO - if q0 == E::ZERO || q1 == E::ZERO { - return Err(VerifierError::ZeroOutputDenominator); - } - - // check that the output matches the expected `claim` let claim = evaluator.compute_claim(pub_inputs, &logup_randomness); - if (p0 * q1 + p1 * q0) / (q0 * q1) != claim { + let mut num_acc = E::ZERO; + let mut den_acc = E::ONE; + for (circuit_id, (nums, dens)) in + numerators.into_iter().zip(denominators.into_iter()).enumerate() + { + let mut evaluations = nums.evaluations().to_vec(); + evaluations.extend_from_slice(&dens.evaluations()); + transcript.reseed(H::hash_elements(&evaluations)); + + let p0 = nums.evaluations()[0]; + let p1 = nums.evaluations()[1]; + let q0 = dens.evaluations()[0]; + let q1 = dens.evaluations()[1]; + + // make sure that both denominators are not equal to E::ZERO + if q0 == E::ZERO || q1 == E::ZERO { + libc_println!("p0 is zero"); + return Err(VerifierError::ZeroOutputDenominator); + } + + let cur_num = (p0 * q1 + p1 * q0); + let cur_den = q0 * q1; + + let new_num = num_acc * cur_den + den_acc * cur_num; + let new_den = den_acc * cur_den; + num_acc = new_num; + den_acc = new_den; + } + if num_acc != claim || den_acc == E::ZERO { return Err(VerifierError::MismatchingCircuitOutput); } + // generate the random challenge to reduce two claims into a single claim - let mut evaluations = numerators.evaluations().to_vec(); - evaluations.extend_from_slice(denominators.evaluations()); - transcript.reseed(H::hash_elements(&evaluations)); let r = transcript.draw().map_err(|_| VerifierError::FailedToGenerateChallenge)?; // reduce the claim - let p_r = p0 + r * (p1 - p0); - let q_r = q0 + r * (q1 - q0); - let mut reduced_claim = (p_r, q_r); + let mut reduced_claims = vec![]; + for (circuit_id, (nums, dens)) in + numerators.into_iter().zip(denominators.into_iter()).enumerate() + { + let p0 = nums.evaluations()[0]; + let p1 = nums.evaluations()[1]; + let q0 = dens.evaluations()[0]; + let q1 = dens.evaluations()[1]; + // reduce the claim + let p_r = p0 + r * (p1 - p0); + let q_r = q0 + r * (q1 - q0); + let reduced_claim = (p_r, q_r); + reduced_claims.push(reduced_claim) + } + + let num_circuits = reduced_claims.len(); + let log_num_circuits = num_circuits.ilog2(); + assert_eq!(1 << log_num_circuits, num_circuits); + + let mut circuit_batching_randomness: Vec = vec![]; + + for _ in 0..log_num_circuits { + let batching_r = transcript.draw().map_err(|_| VerifierError::FailedToGenerateChallenge)?; + circuit_batching_randomness.push(batching_r); + } + + let tensored_circuit_batching_randomness = + EqFunction::new(circuit_batching_randomness.into()).evaluations(); + // verify all GKR layers but for the last one let num_layers = before_final_layer_proofs.proof.len(); @@ -68,19 +109,26 @@ pub fn verify_gkr< let FinalOpeningClaim { eval_point, openings } = verify_sum_check_intermediate_layers( &before_final_layer_proofs.proof[i], &evaluation_point, - reduced_claim, + &reduced_claims, + &tensored_circuit_batching_randomness, transcript, )?; // generate the random challenge to reduce two claims into a single claim - transcript.reseed(H::hash_elements(&openings)); + for tmp in openings.iter() { + transcript.reseed(H::hash_elements(&tmp)); + } let r_layer = transcript.draw().map_err(|_| VerifierError::FailedToGenerateChallenge)?; - let p0 = openings[0]; - let p1 = openings[1]; - let q0 = openings[2]; - let q1 = openings[3]; - reduced_claim = (p0 + r_layer * (p1 - p0), q0 + r_layer * (q1 - q0)); + for (j, ops) in openings.iter().enumerate() { + let p0 = ops[0]; + let p1 = ops[1]; + let q0 = ops[2]; + let q1 = ops[3]; + + let reduced_claim = (p0 + r_layer * (p1 - p0), q0 + r_layer * (q1 - q0)); + reduced_claims[j] = reduced_claim; + } // collect the randomness used for the current layer let mut ext = eval_point.clone(); @@ -95,7 +143,8 @@ pub fn verify_gkr< final_layer_proof, logup_randomness, &evaluation_point, - reduced_claim, + reduced_claims, + &tensored_circuit_batching_randomness, transcript, ) .map_err(VerifierError::FailedToVerifySumCheck) From 1c594cdf5e48e1db7a4c9d668799329d0f379948 Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Mon, 23 Sep 2024 20:17:22 +0200 Subject: [PATCH 04/44] wip: POC parallel circuits w/o periodic columns support --- air/src/air/logup_gkr/mod.rs | 2 - prover/Cargo.toml | 1 + prover/src/logup_gkr/mod.rs | 70 ++++++------------------ prover/src/logup_gkr/prover.rs | 36 ++++++------ sumcheck/benches/sum_check_plain.rs | 4 +- sumcheck/src/lib.rs | 22 ++------ sumcheck/src/prover/high_degree.rs | 47 +++++++++------- sumcheck/src/prover/plain.rs | 65 +++++++++++----------- sumcheck/src/verifier/mod.rs | 32 +++++++---- verifier/src/logup_gkr/mod.rs | 9 +-- winterfell/src/tests/logup_gkr_simple.rs | 36 ++++++------ 11 files changed, 141 insertions(+), 183 deletions(-) diff --git a/air/src/air/logup_gkr/mod.rs b/air/src/air/logup_gkr/mod.rs index 27a9d6c78..d1f0f4c49 100644 --- a/air/src/air/logup_gkr/mod.rs +++ b/air/src/air/logup_gkr/mod.rs @@ -4,7 +4,6 @@ // LICENSE file in the root directory of this source tree. use alloc::vec::Vec; -use libc_print::libc_println; use core::marker::PhantomData; use crypto::{ElementHasher, RandomCoin}; @@ -132,7 +131,6 @@ pub trait LogUpGkrEvaluator: Clone + Sync { } let mut eval_point = eval_point; eval_point.push(folding_randomness); - libc_println!("folding ranomnes {:?}", folding_randomness); GkrData::new( LagrangeKernelRandElements::new(eval_point), diff --git a/prover/Cargo.toml b/prover/Cargo.toml index 49f5f71cd..e1a47363e 100644 --- a/prover/Cargo.toml +++ b/prover/Cargo.toml @@ -43,6 +43,7 @@ sumcheck = { version = "0.1", path = "../sumcheck", package = "winter-sumcheck", thiserror = { version = "1.0", git = "https://github.com/bitwalker/thiserror", branch = "no-std", default-features = false } tracing = { version = "0.1", default-features = false, features = ["attributes"]} utils = { version = "0.9", path = "../utils/core", package = "winter-utils", default-features = false } +libc-print = "0.1.23" [dev-dependencies] criterion = "0.5" diff --git a/prover/src/logup_gkr/mod.rs b/prover/src/logup_gkr/mod.rs index bcbbbd33c..94d8bfb55 100644 --- a/prover/src/logup_gkr/mod.rs +++ b/prover/src/logup_gkr/mod.rs @@ -1,14 +1,14 @@ use alloc::vec::Vec; -use core::ops::Add; +use core::{ + fmt::{self, Formatter}, + ops::Add, +}; use air::{EvaluationFrame, GkrData, LogUpGkrEvaluator}; use math::FieldElement; use sumcheck::{EqFunction, MultiLinearPoly, SumCheckProverError}; use tracing::instrument; -use utils::{ - batch_iter_mut, chunks, uninit_vector, ByteReader, ByteWriter, Deserializable, - DeserializationError, Serializable, -}; +use utils::{chunks, ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; use crate::Trace; @@ -123,8 +123,7 @@ impl EvaluatedCircuit { let periodic_values = evaluator.build_periodic_values(trace.main_segment().num_rows()); let mut input_layer_wires: Vec> = - // Vec::with_capacity(main_trace.main_segment().num_rows() * num_fractions); - vec![vec![]; num_fractions]; + vec![Vec::with_capacity(main_trace.main_segment().num_rows()); num_fractions]; let mut main_frame = EvaluationFrame::new(main_trace.main_segment().num_cols()); let mut query = vec![E::BaseField::ZERO; evaluator.get_oracles().len()]; @@ -160,7 +159,7 @@ impl EvaluatedCircuit { /// Computes the subsequent layer of the circuit from a given layer. fn compute_next_layer(prev_layers: &[CircuitLayer]) -> Vec> { - let mut next_layers = vec![]; + let mut next_layers = Vec::with_capacity(prev_layers.len() / 2); for prev_layer in prev_layers.iter() { let next_layer_wires = chunks!(prev_layer.wires(), 2) .map(|input_wires| { @@ -215,10 +214,6 @@ where denominators: MultiLinearPoly::from_evaluations(denominators), } } - - fn into_numerators_denominators(self) -> (MultiLinearPoly, MultiLinearPoly) { - (self.numerators, self.denominators) - } } impl Serializable for CircuitLayerPolys @@ -255,6 +250,7 @@ where /// /// Note that a [`Layer`] needs to be first converted to a [`LayerPolys`] before the evaluation of /// the layer can be proved using GKR. +#[derive(Debug)] pub struct CircuitLayer { wires: Vec>, } @@ -288,7 +284,7 @@ impl CircuitLayer { /// /// Hence, addition is defined in the natural way fractions are added together: `a/b + c/d = (ad + /// bc) / bd`. -#[derive(Debug, Clone, Copy)] +#[derive(Clone, Copy)] pub struct CircuitWire { numerator: E, denominator: E, @@ -320,51 +316,17 @@ where } } +impl fmt::Debug for CircuitWire { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "{} / {}", self.numerator, self.denominator) + } +} + /// Represents a claim to be proven by a subsequent call to the sum-check protocol. #[derive(Debug)] pub struct GkrClaim { pub evaluation_point: Vec, - pub claimed_evaluation: Vec<(E, E)>, -} - -/// We receive our 4 multilinear polynomials which were evaluated at a random point: -/// `left_numerators` (or `p0`), `right_numerators` (or `p1`), `left_denominators` (or `q0`), and -/// `right_denominators` (or `q1`). We'll call the 4 evaluations at a random point `p0(r)`, `p1(r)`, -/// `q0(r)`, and `q1(r)`, respectively, where `r` is the random point. Note that `r` is a shorthand -/// for a tuple of random values `(r_0, ... r_{l-1})`, where `2^{l + 1}` is the number of wires in -/// the layer. -/// -/// It is important to recall how `p0` and `p1` were constructed (and analogously for `q0` and -/// `q1`). They are the `numerators` layer polynomial (or `p`) evaluations `p(0, r)` and `p(1, r)`, -/// obtained from [`MultiLinearPoly::project_least_significant_variable`]. Hence, `[p0, p1]` form -/// the evaluations of polynomial `p'(x_0) = p(x_0, r)`. Then, the round claim for `numerators`, -/// defined as `p(r_layer, r)`, is simply `p'(r_layer)`. -fn reduce_layer_claim( - left_numerators_opening: E, - right_numerators_opening: E, - left_denominators_opening: E, - right_denominators_opening: E, - r_layer: E, -) -> (E, E) -where - E: FieldElement, -{ - // This is the `numerators` layer polynomial `f(x_0) = numerators(x_0, rx_0, ..., rx_{l-1})`, - // where `rx_0, ..., rx_{l-1}` are the random variables that were sampled during the sumcheck - // round for this layer. - let numerators_univariate = - MultiLinearPoly::from_evaluations(vec![left_numerators_opening, right_numerators_opening]); - - // This is analogous to `numerators_univariate`, but for the `denominators` layer polynomial - let denominators_univariate = MultiLinearPoly::from_evaluations(vec![ - left_denominators_opening, - right_denominators_opening, - ]); - - ( - numerators_univariate.evaluate(&[r_layer]), - denominators_univariate.evaluate(&[r_layer]), - ) + pub claimed_evaluations_per_circuit: Vec<(E, E)>, } /// Builds the auxiliary trace column for the univariate sum-check argument. diff --git a/prover/src/logup_gkr/prover.rs b/prover/src/logup_gkr/prover.rs index afa53163e..b848d557b 100644 --- a/prover/src/logup_gkr/prover.rs +++ b/prover/src/logup_gkr/prover.rs @@ -4,9 +4,8 @@ use air::{LogUpGkrEvaluator, LogUpGkrOracle, PeriodicTable}; use crypto::{ElementHasher, RandomCoin}; use math::FieldElement; use sumcheck::{ - sum_check_prove_higher_degree, sumcheck_prove_plain_batched, - BeforeFinalLayerProof, CircuitOutput, EqFunction, FinalLayerProof, GkrCircuitProof, - MultiLinearPoly, SumCheckProof, + sum_check_prove_higher_degree, sumcheck_prove_plain_batched, BeforeFinalLayerProof, + CircuitOutput, EqFunction, FinalLayerProof, GkrCircuitProof, MultiLinearPoly, SumCheckProof, }; use tracing::instrument; @@ -69,14 +68,14 @@ pub fn prove_gkr( } // evaluate the GKR fractional sum circuit - let circuit = EvaluatedCircuit::new(main_trace, evaluator, &logup_randomness)?; + let circuits = EvaluatedCircuit::new(main_trace, evaluator, &logup_randomness)?; // include the circuit output as part of the final proof - let output_layers = circuit.output_layers().clone(); + let output_layers = circuits.output_layers().clone(); // run the GKR prover for all layers except the input layer let (before_final_layer_proofs, gkr_claim, tensored_circuit_batching_randomness) = - prove_intermediate_layers(circuit, public_coin)?; + prove_intermediate_layers(circuits, public_coin)?; // build the MLEs of the relevant main trace columns let main_trace_mls = @@ -131,7 +130,7 @@ fn prove_input_layer< // parse the [GkrClaim] resulting from the previous GKR layer let GkrClaim { evaluation_point, - claimed_evaluation: claimed_evaluations, + claimed_evaluations_per_circuit: claimed_evaluations, } = claim; for claimed_evaluation in claimed_evaluations.iter() { transcript.reseed(H::hash_elements(&[claimed_evaluation.0, claimed_evaluation.1])); @@ -201,6 +200,7 @@ fn prove_intermediate_layers< // layer to the verifier. The verifier then replies with a challenge `r` in order to evaluate // `p` and `q` at `r` as multi-linears. let output_layers = circuit.output_layers(); + // TODO: optimize calls to hash function for output_layer in output_layers.into_iter() { let mut evaluations = output_layer.numerators.evaluations().to_vec(); evaluations.extend_from_slice(output_layer.denominators.evaluations()); @@ -214,7 +214,7 @@ fn prove_intermediate_layers< let log_num_circuits = num_circuits.ilog2(); assert_eq!(1 << log_num_circuits, num_circuits); - let mut circuit_batching_randomness: Vec = vec![]; + let mut circuit_batching_randomness: Vec = Vec::with_capacity(log_num_circuits as usize); for _ in 0..log_num_circuits { let batching_r = transcript.draw().map_err(|_| GkrProverError::FailedToGenerateChallenge)?; @@ -224,6 +224,7 @@ fn prove_intermediate_layers< let tensored_circuit_batching_randomness = EqFunction::new(circuit_batching_randomness.into()).evaluations(); + let mut layer_proofs: Vec> = Vec::new(); let mut evaluation_point = vec![r]; @@ -253,12 +254,12 @@ fn prove_intermediate_layers< } let r_layer = transcript.draw().map_err(|_| GkrProverError::FailedToGenerateChallenge)?; - // reduce the claim - for (j, ops) in proof.openings_claim.openings.iter().enumerate() { - let p0 = ops[0]; - let p1 = ops[1]; - let q0 = ops[2]; - let q1 = ops[3]; + // reduce the claims + for (j, claimed_opening) in proof.openings_claim.openings.iter().enumerate() { + let p0 = claimed_opening[0]; + let p1 = claimed_opening[1]; + let q0 = claimed_opening[2]; + let q1 = claimed_opening[3]; let reduced_claim = (p0 + r_layer * (p1 - p0), q0 + r_layer * (q1 - q0)); claimed_evaluations[j] = reduced_claim; @@ -276,7 +277,7 @@ fn prove_intermediate_layers< BeforeFinalLayerProof { proof: layer_proofs }, GkrClaim { evaluation_point, - claimed_evaluation: claimed_evaluations, + claimed_evaluations_per_circuit: claimed_evaluations, }, tensored_circuit_batching_randomness, )) @@ -296,6 +297,7 @@ fn sum_check_prove_num_rounds_degree_3< transcript: &mut C, ) -> Result, GkrProverError> { // generate challenge to batch two sumchecks + // TODO: optimize hash for claim in claims { transcript.reseed(H::hash_elements(&[claim.0, claim.1])); } @@ -311,8 +313,8 @@ fn sum_check_prove_num_rounds_degree_3< let mut numerators_accross_circuits_1 = vec![]; let mut denominators_accross_circuits_1 = vec![]; - for tu in inner_layers { - let CircuitLayerPolys { numerators, denominators } = tu; + for inner_layer in inner_layers { + let CircuitLayerPolys { numerators, denominators } = inner_layer; let (p0, p1) = numerators.project_least_significant_variable(); let (q0, q1) = denominators.project_least_significant_variable(); numerators_accross_circuits_0.push(p0); diff --git a/sumcheck/benches/sum_check_plain.rs b/sumcheck/benches/sum_check_plain.rs index 14fd859ce..c7f0552bd 100644 --- a/sumcheck/benches/sum_check_plain.rs +++ b/sumcheck/benches/sum_check_plain.rs @@ -11,7 +11,7 @@ use math::{fields::f64::BaseElement, FieldElement}; use rand_utils::{rand_value, rand_vector}; #[cfg(feature = "concurrent")] pub use rayon::prelude::*; -use winter_sumcheck::{sumcheck_prove_plain, EqFunction, MultiLinearPoly}; +use winter_sumcheck::{sumcheck_prove_plain_batched, EqFunction, MultiLinearPoly}; const LOG_POLY_SIZE: [usize; 2] = [18, 20]; @@ -32,7 +32,7 @@ fn sum_check_plain(c: &mut Criterion) { let mut eq = eq; let mut transcript = transcript; - sumcheck_prove_plain(claim, r_batch, p, q, &mut eq, &mut transcript) + sumcheck_prove_plain_batched(claim, r_batch, p, q, &mut eq, &mut transcript) }, BatchSize::SmallInput, ) diff --git a/sumcheck/src/lib.rs b/sumcheck/src/lib.rs index 4672fdb7d..e411e87e6 100644 --- a/sumcheck/src/lib.rs +++ b/sumcheck/src/lib.rs @@ -267,20 +267,6 @@ fn comb_func(p0: E, p1: E, q0: E, q1: E, eq: E, r_batch: E) -> /// The non-linear composition polynomial of the LogUp-GKR protocol specific to the input layer. pub fn evaluate_composition_poly( - eq_at_mu: &[E], - numerators: &[E], - denominators: &[E], - eq_eval: E, - r_sum_check: E, -) -> E { - numerators - .chunks(2) - .zip(denominators.chunks(2).zip(eq_at_mu.iter())) - .map(|(p, (q, eq_w))| *eq_w * comb_func(p[0], p[1], q[0], q[1], eq_eval, r_sum_check)) - .fold(E::ZERO, |acc, x| acc + x) -} -/// The non-linear composition polynomial of the LogUp-GKR protocol specific to the input layer. -pub fn evaluate_composition_poly_2( eq_at_mu: &[E], numerators_zero: &[E], denominators_zero: &[E], @@ -294,8 +280,10 @@ pub fn evaluate_composition_poly_2( .zip( numerators_one .iter() - .zip(denominators_zero.iter().zip(denominators_one.iter().zip(eq_at_mu.iter()))) + .zip(denominators_zero.iter().zip(denominators_one.iter().zip(eq_at_mu.iter()))), ) - .map(|(p0, (p1, (q0, (q1, eq_w))))| {*eq_w * comb_func(*p0, *p1, *q0, *q1, eq_eval, r_sum_check)}) + .map(|(p0, (p1, (q0, (q1, eq_w))))| { + *eq_w * comb_func(*p0, *p1, *q0, *q1, eq_eval, r_sum_check) + }) .fold(E::ZERO, |acc, x| acc + x) -} \ No newline at end of file +} diff --git a/sumcheck/src/prover/high_degree.rs b/sumcheck/src/prover/high_degree.rs index e8652a15c..4664777ab 100644 --- a/sumcheck/src/prover/high_degree.rs +++ b/sumcheck/src/prover/high_degree.rs @@ -7,14 +7,14 @@ use alloc::vec::Vec; use air::{LogUpGkrEvaluator, PeriodicTable}; use crypto::{ElementHasher, RandomCoin}; -use libc_print::libc_println; use math::FieldElement; #[cfg(feature = "concurrent")] pub use rayon::prelude::*; use super::SumCheckProverError; use crate::{ - evaluate_composition_poly, evaluate_composition_poly_2, CompressedUnivariatePolyEvals, EqFunction, FinalOpeningClaim, MultiLinearPoly, RoundProof, SumCheckProof, SumCheckRoundClaim + evaluate_composition_poly, CompressedUnivariatePolyEvals, EqFunction, FinalOpeningClaim, + MultiLinearPoly, RoundProof, SumCheckProof, SumCheckRoundClaim, }; /// A sum-check prover for the input layer which can accommodate non-linear expressions in @@ -168,8 +168,7 @@ pub fn sum_check_prove_higher_degree< let mut round_proofs = vec![]; - let mut eq_mle = EqFunction::ml_at(evaluation_point.into()); - + let mut eq_mle = EqFunction::ml_at(evaluation_point.clone().into()); // setup first round claim let mut current_round_claim = SumCheckRoundClaim { eval_point: vec![], claim }; @@ -188,7 +187,7 @@ pub fn sum_check_prove_higher_degree< // reseed with the s_0 polynomial coin.reseed(H::hash_elements(&round_poly_coefs.0)); round_proofs.push(RoundProof { round_poly_coefs }); - + //libc_println!("current_round_claim {:?}", current_round_claim); for i in 1..num_rounds { // generate random challenge r_i for the i-th round let round_challenge = @@ -237,12 +236,17 @@ pub fn sum_check_prove_higher_degree< mls.iter_mut() .for_each(|ml| ml.bind_least_significant_variable(round_challenge)); + // fold each periodic multi-linear using the round challenge + periodic_table.bind_least_significant_variable(round_challenge); + let SumCheckRoundClaim { eval_point, claim: _claim } = reduce_claim(&round_proofs[num_rounds - 1], current_round_claim, round_challenge); - let openings :Vec = mls.into_iter().flat_map(|ml| [ml.evaluations()[0], ml.evaluations()[1]]).collect(); + let openings: Vec = mls + .into_iter() + .flat_map(|ml| [ml.evaluations()[0], ml.evaluations()[1]]) + .collect(); - libc_println!("openings prover {:?}", openings); Ok(SumCheckProof { openings_claim: FinalOpeningClaim { eval_point, openings: vec![openings] }, round_proofs, @@ -318,7 +322,7 @@ fn sumcheck_round( let mut evals_periodic_zero_one = vec![E::ZERO; num_periodic]; let mut evals_periodic_one_zero = vec![E::ZERO; num_periodic]; let mut evals_periodic_one_one = vec![E::ZERO; num_periodic]; - + let mut evals_periodic_x_zero = vec![E::ZERO; num_periodic]; let mut evals_periodic_x_one = vec![E::ZERO; num_periodic]; @@ -337,23 +341,26 @@ fn sumcheck_round( (0..1 << num_rounds) .map(|i| { let mut total_evals = vec![E::ZERO; evaluator.max_degree()]; - for (j, ml) in mls.iter().enumerate() { evals_zero_zero[j] = ml.evaluations()[2 * i]; evals_zero_one[j] = ml.evaluations()[2 * i + 1]; - evals_one_zero[j] = ml.evaluations()[2 * (i + num_rounds)]; - evals_one_one[j] = ml.evaluations()[2 * (i + num_rounds) + 1]; + evals_one_zero[j] = ml.evaluations()[2 * i + 2 * (1 << num_rounds)]; + evals_one_one[j] = ml.evaluations()[2 * i + 2 * (1 << num_rounds) + 1]; } - let eq_at_zero = eq_ml.evaluations()[i]; - let eq_at_one = eq_ml.evaluations()[i + ( num_rounds)]; + let eq_at_one = eq_ml.evaluations()[i + (1 << num_rounds)]; // add evaluation of periodic columns periodic_table.fill_periodic_values_at(2 * i, &mut evals_periodic_zero_zero); periodic_table.fill_periodic_values_at(2 * i + 1, &mut evals_periodic_zero_one); - periodic_table.fill_periodic_values_at(2 * (i + num_rounds), &mut evals_periodic_one_zero); - periodic_table.fill_periodic_values_at(2 * (i + num_rounds) + 1, &mut evals_periodic_one_one); - + periodic_table.fill_periodic_values_at( + 2 * i + 2 * (1 << num_rounds), + &mut evals_periodic_one_zero, + ); + periodic_table.fill_periodic_values_at( + 2 * i + 2 * (1 << num_rounds) + 1, + &mut evals_periodic_one_one, + ); // compute the evaluation at 1 evaluator.evaluate_query( @@ -370,11 +377,11 @@ fn sumcheck_round( &mut numerators_one, &mut denominators_one, ); - total_evals[0] = evaluate_composition_poly_2( + total_evals[0] = evaluate_composition_poly( eq_mu, &numerators_zero, - &numerators_one, &denominators_zero, + &numerators_one, &denominators_one, eq_at_one, r_sum_check, @@ -430,11 +437,11 @@ fn sumcheck_round( &mut numerators_one, &mut denominators_one, ); - *e = evaluate_composition_poly_2( + *e = evaluate_composition_poly( eq_mu, &numerators_zero, - &numerators_one, &denominators_zero, + &numerators_one, &denominators_one, eq_x, r_sum_check, diff --git a/sumcheck/src/prover/plain.rs b/sumcheck/src/prover/plain.rs index db9864899..47d432973 100644 --- a/sumcheck/src/prover/plain.rs +++ b/sumcheck/src/prover/plain.rs @@ -4,6 +4,7 @@ // LICENSE file in the root directory of this source tree. use alloc::vec::Vec; + use crypto::{ElementHasher, RandomCoin}; use math::FieldElement; #[cfg(feature = "concurrent")] @@ -16,8 +17,6 @@ use crate::{ SumCheckProof, }; -/* - /// Sum-check prover for non-linear multivariate polynomial of the simple LogUp-GKR. /// /// More specifically, the following function implements the logic of the sum-check prover as @@ -42,7 +41,7 @@ use crate::{ /// q_{\nu-\kappa+1}\left(0, w_{\kappa+1}, \cdots, w_{\nu}\right)\right) /// $$ /// -/// for $k = 1, \cdots, \nu - 1$ +/// for $k = 1, \cdots, \nu - 1$ /// /// Instead of executing two runs of the sum-check protocol, a batching randomness `r_batch` is /// sent by the verifier at the outset in order to batch the two statments. @@ -216,7 +215,7 @@ pub fn sumcheck_prove_plain( eval_point: &[E], ) -> (Vec, Vec) { let mut eval_point_zero = vec![E::ZERO]; - let mut eval_point_one = vec![E::ZERO]; + let mut eval_point_one = vec![E::ONE]; eval_point_zero.extend_from_slice(&eval_point); eval_point_one.extend_from_slice(&eval_point); @@ -251,8 +259,8 @@ fn evaluate_periodic_columns_at( let evaluation_zero = ml.evaluate(point_zero); evaluations_zero.push(evaluation_zero); - let evaluation_one= ml.evaluate(point_one); + let evaluation_one = ml.evaluate(point_one); evaluations_one.push(evaluation_one) } (evaluations_zero, evaluations_one) -} \ No newline at end of file +} diff --git a/verifier/src/logup_gkr/mod.rs b/verifier/src/logup_gkr/mod.rs index fa35e9c54..b2101e2f1 100644 --- a/verifier/src/logup_gkr/mod.rs +++ b/verifier/src/logup_gkr/mod.rs @@ -2,7 +2,6 @@ use alloc::vec::Vec; use air::{Air, LogUpGkrEvaluator}; use crypto::{ElementHasher, RandomCoin}; -use libc_print::libc_println; use math::FieldElement; use sumcheck::{ verify_sum_check_input_layer, verify_sum_check_intermediate_layers, CircuitOutput, EqFunction, FinalOpeningClaim, GkrCircuitProof, SumCheckVerifierError @@ -37,7 +36,7 @@ pub fn verify_gkr< let claim = evaluator.compute_claim(pub_inputs, &logup_randomness); let mut num_acc = E::ZERO; let mut den_acc = E::ONE; - for (circuit_id, (nums, dens)) in + for (_circuit_id, (nums, dens)) in numerators.into_iter().zip(denominators.into_iter()).enumerate() { let mut evaluations = nums.evaluations().to_vec(); @@ -51,11 +50,10 @@ pub fn verify_gkr< // make sure that both denominators are not equal to E::ZERO if q0 == E::ZERO || q1 == E::ZERO { - libc_println!("p0 is zero"); return Err(VerifierError::ZeroOutputDenominator); } - let cur_num = (p0 * q1 + p1 * q0); + let cur_num = p0 * q1 + p1 * q0; let cur_den = q0 * q1; let new_num = num_acc * cur_den + den_acc * cur_num; @@ -73,7 +71,7 @@ pub fn verify_gkr< // reduce the claim let mut reduced_claims = vec![]; - for (circuit_id, (nums, dens)) in + for (_circuit_id, (nums, dens)) in numerators.into_iter().zip(denominators.into_iter()).enumerate() { let p0 = nums.evaluations()[0]; @@ -100,7 +98,6 @@ pub fn verify_gkr< let tensored_circuit_batching_randomness = EqFunction::new(circuit_batching_randomness.into()).evaluations(); - // verify all GKR layers but for the last one let num_layers = before_final_layer_proofs.proof.len(); diff --git a/winterfell/src/tests/logup_gkr_simple.rs b/winterfell/src/tests/logup_gkr_simple.rs index 6c814c948..cfa2743dc 100644 --- a/winterfell/src/tests/logup_gkr_simple.rs +++ b/winterfell/src/tests/logup_gkr_simple.rs @@ -23,7 +23,8 @@ use crate::{ #[test] fn test_logup_gkr() { let aux_trace_width = 1; - let trace = LogUpGkrSimple::new(2_usize.pow(7), aux_trace_width); + let trace = LogUpGkrSimple::new(2_usize.pow(13 + ), aux_trace_width); let prover = LogUpGkrSimpleProver::new(aux_trace_width); let proof = prover.prove(trace).unwrap(); @@ -193,22 +194,14 @@ impl PlainLogUpGkrEval { let committed_2 = LogUpGkrOracle::CurrentRow(2); let committed_3 = LogUpGkrOracle::CurrentRow(3); let committed_4 = LogUpGkrOracle::CurrentRow(4); - let committed_0_next_row = LogUpGkrOracle::NextRow(0); - let committed_1_next_row = LogUpGkrOracle::NextRow(1); - let committed_2_next_row = LogUpGkrOracle::NextRow(2); - let committed_3_next_row = LogUpGkrOracle::NextRow(3); - let committed_4_next_row = LogUpGkrOracle::NextRow(4); + let oracles = vec![ committed_0, committed_1, committed_2, committed_3, committed_4, - committed_0_next_row, - committed_1_next_row, - committed_2_next_row, - committed_3_next_row, - committed_4_next_row, + ]; Self { oracles, _field: PhantomData } } @@ -229,6 +222,7 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { fn get_num_fractions(&self) -> usize { 4 + //2 } fn max_degree(&self) -> usize { @@ -245,11 +239,6 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { query[3] = frame.current()[3]; query[4] = frame.current()[4]; - query[5] = frame.next()[0]; - query[6] = frame.next()[1]; - query[7] = frame.next()[2]; - query[8] = frame.next()[3]; - query[9] = frame.next()[4]; } fn evaluate_query( @@ -263,9 +252,9 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { F: FieldElement, E: FieldElement + ExtensionOf, { - assert_eq!(numerator.len(), 4); - assert_eq!(denominator.len(), 4); - assert_eq!(query.len(), 10); + //assert_eq!(numerator.len(), 4); + //assert_eq!(denominator.len(), 4); + //assert_eq!(query.len(), 10); numerator[0] = E::from(query[1]); numerator[1] = E::ONE; numerator[2] = E::ONE; @@ -275,6 +264,13 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { denominator[1] = -(rand_values[0] - E::from(query[2])); denominator[2] = -(rand_values[0] - E::from(query[3])); denominator[3] = -(rand_values[0] - E::from(query[4])); + + + + //numerator[0] = -E::ONE; + //numerator[1] = E::ONE; + //denominator[2] = E::ONE; + //denominator[3] = E::ONE; } fn compute_claim(&self, _inputs: &Self::PublicInputs, _rand_values: &[E]) -> E @@ -296,7 +292,7 @@ impl LogUpGkrSimpleProver { fn new(aux_trace_width: usize) -> Self { Self { aux_trace_width, - options: ProofOptions::new(1, 8, 0, FieldExtension::Quadratic, 2, 1), + options: ProofOptions::new(1, 8, 0, FieldExtension::None, 2, 1), } } } From 3ea525da8e74319db6e5db37db83910c40ce9e64 Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Mon, 23 Sep 2024 20:33:03 +0200 Subject: [PATCH 05/44] fix: use with_capacity --- prover/src/logup_gkr/prover.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/prover/src/logup_gkr/prover.rs b/prover/src/logup_gkr/prover.rs index b848d557b..0035ce92e 100644 --- a/prover/src/logup_gkr/prover.rs +++ b/prover/src/logup_gkr/prover.rs @@ -308,10 +308,10 @@ fn sum_check_prove_num_rounds_degree_3< batched_claims.push(claim) } - let mut numerators_accross_circuits_0 = vec![]; - let mut denominators_accross_circuits_0 = vec![]; - let mut numerators_accross_circuits_1 = vec![]; - let mut denominators_accross_circuits_1 = vec![]; + let mut numerators_accross_circuits_0 = Vec::with_capacity(inner_layers.len()); + let mut denominators_accross_circuits_0 = Vec::with_capacity(inner_layers.len()); + let mut numerators_accross_circuits_1 = Vec::with_capacity(inner_layers.len()); + let mut denominators_accross_circuits_1 = Vec::with_capacity(inner_layers.len()); for inner_layer in inner_layers { let CircuitLayerPolys { numerators, denominators } = inner_layer; From 6fcab1755e0c32dfd2b5280bd3eb78c0ba69689f Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Mon, 23 Sep 2024 21:30:21 +0200 Subject: [PATCH 06/44] wip: bring back concurrency --- sumcheck/src/prover/high_degree.rs | 161 +++++++++++++++++++++-------- sumcheck/src/prover/plain.rs | 21 ++-- 2 files changed, 127 insertions(+), 55 deletions(-) diff --git a/sumcheck/src/prover/high_degree.rs b/sumcheck/src/prover/high_degree.rs index 4664777ab..25fe28ddf 100644 --- a/sumcheck/src/prover/high_degree.rs +++ b/sumcheck/src/prover/high_degree.rs @@ -464,79 +464,130 @@ fn sumcheck_round( .fold( || { ( + vec![E::ZERO; num_mls], + vec![E::ZERO; num_mls], + vec![E::ZERO; num_mls], vec![E::ZERO; num_mls], vec![E::ZERO; num_mls], vec![E::ZERO; num_mls], vec![E::ZERO; num_periodic], vec![E::ZERO; num_periodic], vec![E::ZERO; num_periodic], - vec![E::ZERO; evaluator.get_num_fractions()], - vec![E::ZERO; evaluator.get_num_fractions()], + vec![E::ZERO; num_periodic], + vec![E::ZERO; num_periodic], + vec![E::ZERO; num_periodic], vec![E::ZERO; num_mls], + vec![E::ZERO; num_mls], + vec![E::ZERO; num_periodic], vec![E::ZERO; num_periodic], + vec![E::ZERO; evaluator.get_num_fractions()], + vec![E::ZERO; evaluator.get_num_fractions()], + vec![E::ZERO; evaluator.get_num_fractions()], + vec![E::ZERO; evaluator.get_num_fractions()], vec![E::ZERO; evaluator.max_degree()], ) }, |( - mut evals_zero, - mut evals_one, - mut evals_x, - mut evals_periodic_zero, - mut evals_periodic_one, - mut evals_periodic_x, - mut numerators, - mut denominators, - mut deltas, - mut deltas_periodic, + mut evals_one_zero, + mut evals_one_one, + mut evals_zero_zero, + mut evals_zero_one, + mut evals_x_zero, + mut evals_x_one, + mut evals_periodic_zero_zero, + mut evals_periodic_zero_one, + mut evals_periodic_one_zero, + mut evals_periodic_one_one, + mut evals_periodic_x_zero, + mut evals_periodic_x_one, + mut deltas_zero, + mut deltas_one, + mut deltas_periodic_zero, + mut deltas_periodic_one, + mut numerators_zero, + mut numerators_one, + mut denominators_zero, + mut denominators_one, mut poly_evals, ), i| { for (j, ml) in mls.iter().enumerate() { - evals_zero[j] = ml.evaluations()[i]; - evals_one[j] = ml.evaluations()[i + (1 << num_rounds)]; + evals_zero_zero[j] = ml.evaluations()[2 * i]; + evals_zero_one[j] = ml.evaluations()[2 * i + 1]; + evals_one_zero[j] = ml.evaluations()[2 * i + 2 * (1 << num_rounds)]; + evals_one_one[j] = ml.evaluations()[2 * i + 2 * (1 << num_rounds) + 1]; } let eq_at_zero = eq_ml.evaluations()[i]; let eq_at_one = eq_ml.evaluations()[i + (1 << num_rounds)]; // add evaluation of periodic columns - periodic_table.fill_periodic_values_at(i, &mut evals_periodic_zero); - periodic_table - .fill_periodic_values_at(i + (1 << num_rounds), &mut evals_periodic_one); + periodic_table.fill_periodic_values_at(2 * i, &mut evals_periodic_zero_zero); + periodic_table.fill_periodic_values_at(2 * i + 1, &mut evals_periodic_zero_one); + periodic_table.fill_periodic_values_at( + 2 * i + 2 * (1 << num_rounds), + &mut evals_periodic_one_zero, + ); + periodic_table.fill_periodic_values_at( + 2 * i + 2 * (1 << num_rounds) + 1, + &mut evals_periodic_one_one, + ); // compute the evaluation at 1 evaluator.evaluate_query( - &evals_one, - &evals_periodic_one, + &evals_one_zero, + &evals_periodic_one_zero, log_up_randomness, - &mut numerators, - &mut denominators, + &mut numerators_zero, + &mut denominators_zero, + ); + evaluator.evaluate_query( + &evals_one_one, + &evals_periodic_one_one, + log_up_randomness, + &mut numerators_one, + &mut denominators_one, ); poly_evals[0] += evaluate_composition_poly( eq_mu, - &numerators, - &denominators, + &numerators_zero, + &denominators_zero, + &numerators_one, + &denominators_one, eq_at_one, r_sum_check, ); // compute the evaluations at 2, ..., d_max points for i in 0..num_mls { - deltas[i] = evals_one[i] - evals_zero[i]; - evals_x[i] = evals_one[i]; + deltas_zero[i] = evals_one_zero[i] - evals_zero_zero[i]; + evals_x_zero[i] = evals_one_zero[i]; + deltas_one[i] = evals_one_one[i] - evals_zero_one[i]; + evals_x_one[i] = evals_one_one[i]; } for i in 0..num_periodic { - deltas_periodic[i] = evals_periodic_one[i] - evals_periodic_zero[i]; - evals_periodic_x[i] = evals_periodic_one[i]; + deltas_periodic_zero[i] = + evals_periodic_zero_one[i] - evals_periodic_zero_zero[i]; + evals_periodic_x_zero[i] = evals_periodic_zero_one[i]; + deltas_periodic_one[i] = evals_periodic_one_one[i] - evals_periodic_one_zero[i]; + evals_periodic_x_one[i] = evals_periodic_one_one[i]; } let eq_delta = eq_at_one - eq_at_zero; let mut eq_x = eq_at_one; for e in poly_evals.iter_mut().skip(1) { - evals_x.iter_mut().zip(deltas.iter()).for_each(|(evx, delta)| { + evals_x_zero.iter_mut().zip(deltas_zero.iter()).for_each(|(evx, delta)| { + *evx += *delta; + }); + evals_periodic_x_zero.iter_mut().zip(deltas_periodic_zero.iter()).for_each( + |(evx, delta)| { + *evx += *delta; + }, + ); + evals_x_one.iter_mut().zip(deltas_one.iter()).for_each(|(evx, delta)| { *evx += *delta; }); - evals_periodic_x.iter_mut().zip(deltas_periodic.iter()).for_each( + evals_periodic_x_one.iter_mut().zip(deltas_periodic_one.iter()).for_each( |(evx, delta)| { *evx += *delta; }, @@ -544,32 +595,52 @@ fn sumcheck_round( eq_x += eq_delta; evaluator.evaluate_query( - &evals_x, - &evals_periodic_x, + &evals_x_zero, + &evals_periodic_x_zero, log_up_randomness, - &mut numerators, - &mut denominators, + &mut numerators_zero, + &mut denominators_zero, + ); + evaluator.evaluate_query( + &evals_x_one, + &evals_periodic_x_one, + log_up_randomness, + &mut numerators_one, + &mut denominators_one, ); *e += evaluate_composition_poly( eq_mu, - &numerators, - &denominators, + &numerators_zero, + &denominators_zero, + &numerators_one, + &denominators_one, eq_x, r_sum_check, ); } ( - evals_zero, - evals_one, - evals_x, - evals_periodic_zero, - evals_periodic_one, - evals_periodic_x, - numerators, - denominators, - deltas, - deltas_periodic, + evals_one_zero, + evals_one_one, + evals_zero_zero, + evals_zero_one, + evals_x_zero, + evals_x_one, + evals_periodic_zero_zero, + evals_periodic_zero_one, + evals_periodic_one_zero, + evals_periodic_one_one, + evals_periodic_x_zero, + evals_periodic_x_one, + + deltas_zero, + deltas_one, + deltas_periodic_zero, + deltas_periodic_one, + numerators_zero, + numerators_one, + denominators_zero, + denominators_one, poly_evals, ) }, diff --git a/sumcheck/src/prover/plain.rs b/sumcheck/src/prover/plain.rs index 47d432973..9327a14b1 100644 --- a/sumcheck/src/prover/plain.rs +++ b/sumcheck/src/prover/plain.rs @@ -10,6 +10,7 @@ use math::FieldElement; #[cfg(feature = "concurrent")] pub use rayon::prelude::*; use smallvec::smallvec; +use utils::{iter, iter_mut}; use super::SumCheckProverError; use crate::{ @@ -247,10 +248,13 @@ pub fn sumcheck_prove_plain_batched Date: Mon, 23 Sep 2024 22:14:26 +0200 Subject: [PATCH 07/44] batch hashing and move circuit layer to sum-check --- prover/benches/logup_gkr.rs | 32 ++++++- prover/src/logup_gkr/mod.rs | 147 +----------------------------- prover/src/logup_gkr/prover.rs | 17 ++-- sumcheck/src/prover/mod.rs | 158 +++++++++++++++++++++++++++++++++ sumcheck/src/prover/plain.rs | 59 +++++++++++- sumcheck/src/verifier/mod.rs | 5 +- verifier/src/logup_gkr/mod.rs | 15 ++-- 7 files changed, 267 insertions(+), 166 deletions(-) diff --git a/prover/benches/logup_gkr.rs b/prover/benches/logup_gkr.rs index 6e67eddc2..e484955bf 100644 --- a/prover/benches/logup_gkr.rs +++ b/prover/benches/logup_gkr.rs @@ -237,11 +237,11 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { } fn get_num_fractions(&self) -> usize { - 4 + 16 } fn max_degree(&self) -> usize { - 3 + 10 } fn build_query(&self, frame: &EvaluationFrame, query: &mut [E]) @@ -262,18 +262,42 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { F: FieldElement, E: FieldElement + ExtensionOf, { - assert_eq!(numerator.len(), 4); - assert_eq!(denominator.len(), 4); + assert_eq!(numerator.len(), 16); + assert_eq!(denominator.len(), 16); assert_eq!(query.len(), 5); numerator[0] = E::from(query[1]); numerator[1] = E::ONE; numerator[2] = E::ONE; numerator[3] = E::ONE; + numerator[4] = E::from(query[1]); + numerator[5] = E::ONE; + numerator[6] = E::ONE; + numerator[7] = E::ONE; + numerator[8] = E::from(query[1]); + numerator[9] = E::ONE; + numerator[10] = E::ONE; + numerator[11] = E::ONE; + numerator[12] = E::from(query[1]); + numerator[13] = E::ONE; + numerator[14] = E::ONE; + numerator[15] = E::ONE; denominator[0] = rand_values[0] - E::from(query[0]); denominator[1] = -(rand_values[0] - E::from(query[2])); denominator[2] = -(rand_values[0] - E::from(query[3])); denominator[3] = -(rand_values[0] - E::from(query[4])); + denominator[4] = rand_values[0] - E::from(query[0]); + denominator[5] = -(rand_values[0] - E::from(query[2])); + denominator[6] = -(rand_values[0] - E::from(query[3])); + denominator[7] = -(rand_values[0] - E::from(query[4])); + denominator[8] = rand_values[0] - E::from(query[0]); + denominator[9] = -(rand_values[0] - E::from(query[2])); + denominator[10] = -(rand_values[0] - E::from(query[3])); + denominator[11] = -(rand_values[0] - E::from(query[4])); + denominator[12] = rand_values[0] - E::from(query[0]); + denominator[13] = -(rand_values[0] - E::from(query[2])); + denominator[14] = -(rand_values[0] - E::from(query[3])); + denominator[15] = -(rand_values[0] - E::from(query[4])); } fn compute_claim(&self, _inputs: &Self::PublicInputs, _rand_values: &[E]) -> E diff --git a/prover/src/logup_gkr/mod.rs b/prover/src/logup_gkr/mod.rs index 94d8bfb55..e9736aa01 100644 --- a/prover/src/logup_gkr/mod.rs +++ b/prover/src/logup_gkr/mod.rs @@ -6,7 +6,7 @@ use core::{ use air::{EvaluationFrame, GkrData, LogUpGkrEvaluator}; use math::FieldElement; -use sumcheck::{EqFunction, MultiLinearPoly, SumCheckProverError}; +use sumcheck::{CircuitLayer, CircuitLayerPolys, CircuitWire, EqFunction, MultiLinearPoly, SumCheckProverError}; use tracing::instrument; use utils::{chunks, ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; @@ -177,151 +177,6 @@ impl EvaluatedCircuit { } } -// CIRCUIT LAYER POLYS -// =============================================================================================== - -/// Holds a layer of an [`EvaluatedCircuit`] in a representation amenable to proving circuit -/// evaluation using GKR. -#[derive(Clone, Debug)] -pub struct CircuitLayerPolys { - pub numerators: MultiLinearPoly, - pub denominators: MultiLinearPoly, -} - -impl CircuitLayerPolys -where - E: FieldElement, -{ - pub fn from_circuit_layer(layers: &[CircuitLayer]) -> Vec { - let mut result = vec![]; - for layer in layers { - result.push(Self::from_wires(layer.wires.clone())) - } - result - } - - pub fn from_wires(wires: Vec>) -> Self { - let mut numerators = Vec::new(); - let mut denominators = Vec::new(); - - for wire in wires { - numerators.push(wire.numerator); - denominators.push(wire.denominator); - } - - Self { - numerators: MultiLinearPoly::from_evaluations(numerators), - denominators: MultiLinearPoly::from_evaluations(denominators), - } - } -} - -impl Serializable for CircuitLayerPolys -where - E: FieldElement, -{ - fn write_into(&self, target: &mut W) { - let Self { numerators, denominators } = self; - numerators.write_into(target); - denominators.write_into(target); - } -} - -impl Deserializable for CircuitLayerPolys -where - E: FieldElement, -{ - fn read_from(source: &mut R) -> Result { - Ok(Self { - numerators: MultiLinearPoly::read_from(source)?, - denominators: MultiLinearPoly::read_from(source)?, - }) - } -} - -// CIRCUIT LAYER -// =============================================================================================== - -/// Represents a layer in a [`EvaluatedCircuit`]. -/// -/// A layer is made up of a set of `n` wires, where `n` is a power of two. This is the natural -/// circuit representation of a layer, where each consecutive pair of wires are summed to yield a -/// wire in the subsequent layer of an [`EvaluatedCircuit`]. -/// -/// Note that a [`Layer`] needs to be first converted to a [`LayerPolys`] before the evaluation of -/// the layer can be proved using GKR. -#[derive(Debug)] -pub struct CircuitLayer { - wires: Vec>, -} - -impl CircuitLayer { - /// Creates a new [`Layer`] from a set of projective coordinates. - /// - /// Panics if the number of projective coordinates is not a power of two. - pub fn new(wires: Vec>) -> Self { - assert!(wires.len().is_power_of_two()); - - Self { wires } - } - - /// Returns the wires that make up this circuit layer. - pub fn wires(&self) -> &[CircuitWire] { - &self.wires - } - - /// Returns the number of wires in the layer. - pub fn num_wires(&self) -> usize { - self.wires.len() - } -} - -// CIRCUIT WIRE -// =============================================================================================== - -/// Represents a fraction `numerator / denominator` as a pair `(numerator, denominator)`. This is -/// the type for the gates' inputs in [`prover::EvaluatedCircuit`]. -/// -/// Hence, addition is defined in the natural way fractions are added together: `a/b + c/d = (ad + -/// bc) / bd`. -#[derive(Clone, Copy)] -pub struct CircuitWire { - numerator: E, - denominator: E, -} - -impl CircuitWire -where - E: FieldElement, -{ - /// Creates new projective coordinates from a numerator and a denominator. - pub fn new(numerator: E, denominator: E) -> Self { - assert_ne!(denominator, E::ZERO); - - Self { numerator, denominator } - } -} - -impl Add for CircuitWire -where - E: FieldElement, -{ - type Output = Self; - - fn add(self, other: Self) -> Self { - let numerator = self.numerator * other.denominator + other.numerator * self.denominator; - let denominator = self.denominator * other.denominator; - - Self::new(numerator, denominator) - } -} - -impl fmt::Debug for CircuitWire { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - write!(f, "{} / {}", self.numerator, self.denominator) - } -} - /// Represents a claim to be proven by a subsequent call to the sum-check protocol. #[derive(Debug)] pub struct GkrClaim { diff --git a/prover/src/logup_gkr/prover.rs b/prover/src/logup_gkr/prover.rs index 0035ce92e..63dd02df9 100644 --- a/prover/src/logup_gkr/prover.rs +++ b/prover/src/logup_gkr/prover.rs @@ -200,12 +200,14 @@ fn prove_intermediate_layers< // layer to the verifier. The verifier then replies with a challenge `r` in order to evaluate // `p` and `q` at `r` as multi-linears. let output_layers = circuit.output_layers(); - // TODO: optimize calls to hash function + + let mut total_evaluations = + Vec::with_capacity(output_layers[0].numerators.evaluations().len() * 2); for output_layer in output_layers.into_iter() { - let mut evaluations = output_layer.numerators.evaluations().to_vec(); - evaluations.extend_from_slice(output_layer.denominators.evaluations()); - transcript.reseed(H::hash_elements(&evaluations)); + total_evaluations.extend_from_slice(&output_layer.numerators.evaluations()); + total_evaluations.extend_from_slice(output_layer.denominators.evaluations()); } + transcript.reseed(H::hash_elements(&total_evaluations)); // generate the challenge and reduce [p0, p1, q0, q1] to [pr, qr] let r = transcript.draw().map_err(|_| GkrProverError::FailedToGenerateChallenge)?; @@ -224,7 +226,6 @@ fn prove_intermediate_layers< let tensored_circuit_batching_randomness = EqFunction::new(circuit_batching_randomness.into()).evaluations(); - let mut layer_proofs: Vec> = Vec::new(); let mut evaluation_point = vec![r]; @@ -297,10 +298,12 @@ fn sum_check_prove_num_rounds_degree_3< transcript: &mut C, ) -> Result, GkrProverError> { // generate challenge to batch two sumchecks - // TODO: optimize hash + let mut concatenated_claims = Vec::with_capacity(claims.len() * 2); for claim in claims { - transcript.reseed(H::hash_elements(&[claim.0, claim.1])); + concatenated_claims.extend_from_slice(&[claim.0, claim.1]); } + transcript.reseed(H::hash_elements(&concatenated_claims)); + let r_batch = transcript.draw().map_err(|_| GkrProverError::FailedToGenerateChallenge)?; let mut batched_claims = vec![]; for claim in claims { diff --git a/sumcheck/src/prover/mod.rs b/sumcheck/src/prover/mod.rs index 0a25b727a..8267acdb9 100644 --- a/sumcheck/src/prover/mod.rs +++ b/sumcheck/src/prover/mod.rs @@ -4,11 +4,169 @@ // LICENSE file in the root directory of this source tree. mod high_degree; +use core::{fmt::Formatter, ops::Add}; + +use alloc::{fmt, vec::Vec}; pub use high_degree::sum_check_prove_higher_degree; mod plain; +use math::FieldElement; //pub use plain::sumcheck_prove_plain; pub use plain::sumcheck_prove_plain_batched; mod error; pub use error::SumCheckProverError; +use utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; + +use crate::MultiLinearPoly; + +// CIRCUIT LAYER POLYS +// =============================================================================================== + +/// Holds a layer of an [`EvaluatedCircuit`] in a representation amenable to proving circuit +/// evaluation using GKR. +#[derive(Clone, Debug)] +pub struct CircuitLayerPolys { + pub numerators: MultiLinearPoly, + pub denominators: MultiLinearPoly, +} + + +impl Serializable for CircuitLayerPolys +where + E: FieldElement, +{ + fn write_into(&self, target: &mut W) { + let Self { numerators, denominators } = self; + numerators.write_into(target); + denominators.write_into(target); + } +} + +impl Deserializable for CircuitLayerPolys +where + E: FieldElement, +{ + fn read_from(source: &mut R) -> Result { + Ok(Self { + numerators: MultiLinearPoly::read_from(source)?, + denominators: MultiLinearPoly::read_from(source)?, + }) + } +} + + +// CIRCUIT LAYER POLYS +// =============================================================================================== + + +impl CircuitLayerPolys +where + E: FieldElement, +{ + pub fn from_circuit_layer(layers: &[CircuitLayer]) -> Vec { + let mut result = vec![]; + for layer in layers { + result.push(Self::from_wires(layer.wires.clone())) + } + result + } + + pub fn from_wires(wires: Vec>) -> Self { + let mut numerators = Vec::new(); + let mut denominators = Vec::new(); + + for wire in wires { + numerators.push(wire.numerator); + denominators.push(wire.denominator); + } + + Self { + numerators: MultiLinearPoly::from_evaluations(numerators), + denominators: MultiLinearPoly::from_evaluations(denominators), + } + } +} + +// CIRCUIT LAYER +// =============================================================================================== + +/// Represents a layer in a [`EvaluatedCircuit`]. +/// +/// A layer is made up of a set of `n` wires, where `n` is a power of two. This is the natural +/// circuit representation of a layer, where each consecutive pair of wires are summed to yield a +/// wire in the subsequent layer of an [`EvaluatedCircuit`]. +/// +/// Note that a [`Layer`] needs to be first converted to a [`LayerPolys`] before the evaluation of +/// the layer can be proved using GKR. +#[derive(Debug)] +pub struct CircuitLayer { + wires: Vec>, +} + +impl CircuitLayer { + /// Creates a new [`Layer`] from a set of projective coordinates. + /// + /// Panics if the number of projective coordinates is not a power of two. + pub fn new(wires: Vec>) -> Self { + assert!(wires.len().is_power_of_two()); + + Self { wires } + } + + /// Returns the wires that make up this circuit layer. + pub fn wires(&self) -> &[CircuitWire] { + &self.wires + } + + /// Returns the number of wires in the layer. + pub fn num_wires(&self) -> usize { + self.wires.len() + } +} + +// CIRCUIT WIRE +// =============================================================================================== + +/// Represents a fraction `numerator / denominator` as a pair `(numerator, denominator)`. This is +/// the type for the gates' inputs in [`prover::EvaluatedCircuit`]. +/// +/// Hence, addition is defined in the natural way fractions are added together: `a/b + c/d = (ad + +/// bc) / bd`. +#[derive(Clone, Copy)] +pub struct CircuitWire { + numerator: E, + denominator: E, +} + +impl CircuitWire +where + E: FieldElement, +{ + /// Creates new projective coordinates from a numerator and a denominator. + pub fn new(numerator: E, denominator: E) -> Self { + assert_ne!(denominator, E::ZERO); + + Self { numerator, denominator } + } +} + +impl Add for CircuitWire +where + E: FieldElement, +{ + type Output = Self; + + fn add(self, other: Self) -> Self { + let numerator = self.numerator * other.denominator + other.numerator * self.denominator; + let denominator = self.denominator * other.denominator; + + Self::new(numerator, denominator) + } +} + +impl fmt::Debug for CircuitWire { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "{} / {}", self.numerator, self.denominator) + } +} diff --git a/sumcheck/src/prover/plain.rs b/sumcheck/src/prover/plain.rs index 9327a14b1..f1394feaf 100644 --- a/sumcheck/src/prover/plain.rs +++ b/sumcheck/src/prover/plain.rs @@ -10,7 +10,6 @@ use math::FieldElement; #[cfg(feature = "concurrent")] pub use rayon::prelude::*; use smallvec::smallvec; -use utils::{iter, iter_mut}; use super::SumCheckProverError; use crate::{ @@ -256,6 +255,7 @@ pub fn sumcheck_prove_plain_batched, ) -> Result, SumCheckVerifierError> { // generate challenge to batch sum-checks + let mut concatenated_claims = Vec::with_capacity(claims.len() * 2); for claim in claims { - transcript.reseed(H::hash_elements(&[claim.0, claim.1])); + concatenated_claims.extend_from_slice(&[claim.0, claim.1]); } + transcript.reseed(H::hash_elements(&concatenated_claims)); + let r_batch: E = transcript .draw() .map_err(|_| SumCheckVerifierError::FailedToGenerateChallenge)?; diff --git a/verifier/src/logup_gkr/mod.rs b/verifier/src/logup_gkr/mod.rs index b2101e2f1..57b18631d 100644 --- a/verifier/src/logup_gkr/mod.rs +++ b/verifier/src/logup_gkr/mod.rs @@ -4,7 +4,8 @@ use air::{Air, LogUpGkrEvaluator}; use crypto::{ElementHasher, RandomCoin}; use math::FieldElement; use sumcheck::{ - verify_sum_check_input_layer, verify_sum_check_intermediate_layers, CircuitOutput, EqFunction, FinalOpeningClaim, GkrCircuitProof, SumCheckVerifierError + verify_sum_check_input_layer, verify_sum_check_intermediate_layers, CircuitOutput, EqFunction, + FinalOpeningClaim, GkrCircuitProof, SumCheckVerifierError, }; /// Verifies the validity of a GKR proof for a LogUp-GKR relation. @@ -34,14 +35,14 @@ pub fn verify_gkr< let CircuitOutput { numerators, denominators } = circuit_outputs; let claim = evaluator.compute_claim(pub_inputs, &logup_randomness); + let mut total_evaluations = Vec::with_capacity(numerators.len() * 4); let mut num_acc = E::ZERO; let mut den_acc = E::ONE; for (_circuit_id, (nums, dens)) in numerators.into_iter().zip(denominators.into_iter()).enumerate() { - let mut evaluations = nums.evaluations().to_vec(); - evaluations.extend_from_slice(&dens.evaluations()); - transcript.reseed(H::hash_elements(&evaluations)); + total_evaluations.extend_from_slice(nums.evaluations()); + total_evaluations.extend_from_slice(dens.evaluations()); let p0 = nums.evaluations()[0]; let p1 = nums.evaluations()[1]; @@ -65,12 +66,12 @@ pub fn verify_gkr< return Err(VerifierError::MismatchingCircuitOutput); } - + transcript.reseed(H::hash_elements(&total_evaluations)); // generate the random challenge to reduce two claims into a single claim let r = transcript.draw().map_err(|_| VerifierError::FailedToGenerateChallenge)?; // reduce the claim - let mut reduced_claims = vec![]; + let mut reduced_claims = vec![]; for (_circuit_id, (nums, dens)) in numerators.into_iter().zip(denominators.into_iter()).enumerate() { @@ -85,7 +86,7 @@ pub fn verify_gkr< reduced_claims.push(reduced_claim) } - let num_circuits = reduced_claims.len(); + let num_circuits = reduced_claims.len(); let log_num_circuits = num_circuits.ilog2(); assert_eq!(1 << log_num_circuits, num_circuits); From e75d47e59d2b08684d51f0a7497c45cf4021fa5a Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Mon, 23 Sep 2024 22:50:16 +0200 Subject: [PATCH 08/44] wip: move CircuitLayerPoly to sum-check crate --- prover/src/logup_gkr/prover.rs | 37 ++++----- sumcheck/src/prover/plain.rs | 143 +++++++++++++++++---------------- 2 files changed, 92 insertions(+), 88 deletions(-) diff --git a/prover/src/logup_gkr/prover.rs b/prover/src/logup_gkr/prover.rs index 63dd02df9..ab39f57b4 100644 --- a/prover/src/logup_gkr/prover.rs +++ b/prover/src/logup_gkr/prover.rs @@ -311,28 +311,29 @@ fn sum_check_prove_num_rounds_degree_3< batched_claims.push(claim) } - let mut numerators_accross_circuits_0 = Vec::with_capacity(inner_layers.len()); - let mut denominators_accross_circuits_0 = Vec::with_capacity(inner_layers.len()); - let mut numerators_accross_circuits_1 = Vec::with_capacity(inner_layers.len()); - let mut denominators_accross_circuits_1 = Vec::with_capacity(inner_layers.len()); - - for inner_layer in inner_layers { - let CircuitLayerPolys { numerators, denominators } = inner_layer; - let (p0, p1) = numerators.project_least_significant_variable(); - let (q0, q1) = denominators.project_least_significant_variable(); - numerators_accross_circuits_0.push(p0); - numerators_accross_circuits_1.push(p1); - denominators_accross_circuits_0.push(q0); - denominators_accross_circuits_1.push(q1) - } + //let mut numerators_accross_circuits_0 = Vec::with_capacity(inner_layers.len()); + //let mut denominators_accross_circuits_0 = Vec::with_capacity(inner_layers.len()); + //let mut numerators_accross_circuits_1 = Vec::with_capacity(inner_layers.len()); + //let mut denominators_accross_circuits_1 = Vec::with_capacity(inner_layers.len()); + + //for inner_layer in inner_layers { + //let CircuitLayerPolys { numerators, denominators } = inner_layer; + //let (p0, p1) = numerators.project_least_significant_variable(); + //let (q0, q1) = denominators.project_least_significant_variable(); + //numerators_accross_circuits_0.push(p0); + //numerators_accross_circuits_1.push(p1); + //denominators_accross_circuits_0.push(q0); + //denominators_accross_circuits_1.push(q1) + //} let proof = sumcheck_prove_plain_batched( &batched_claims, r_batch, - numerators_accross_circuits_0, - numerators_accross_circuits_1, - denominators_accross_circuits_0, - denominators_accross_circuits_1, + inner_layers, + //numerators_accross_circuits_0, + //numerators_accross_circuits_1, + //denominators_accross_circuits_0, + //denominators_accross_circuits_1, eq, tensored_batching_randomness, transcript, diff --git a/sumcheck/src/prover/plain.rs b/sumcheck/src/prover/plain.rs index f1394feaf..9cec143a4 100644 --- a/sumcheck/src/prover/plain.rs +++ b/sumcheck/src/prover/plain.rs @@ -11,7 +11,7 @@ use math::FieldElement; pub use rayon::prelude::*; use smallvec::smallvec; -use super::SumCheckProverError; +use super::{CircuitLayer, CircuitLayerPolys, SumCheckProverError}; use crate::{ comb_func, CompressedUnivariatePolyEvals, FinalOpeningClaim, MultiLinearPoly, RoundProof, SumCheckProof, @@ -224,10 +224,11 @@ pub fn sumcheck_prove_plain>( claims: &[E], r_batch: E, - mut p0_s: Vec>, - mut p1_s: Vec>, - mut q0_s: Vec>, - mut q1_s: Vec>, + mut inner_layers: Vec>, + //mut p0_s: Vec>, + //mut p1_s: Vec>, + //mut q0_s: Vec>, + //mut q1_s: Vec>, eq: &mut MultiLinearPoly, tensored_batching_randomness: &[E], transcript: &mut impl RandomCoin, @@ -240,44 +241,44 @@ pub fn sumcheck_prove_plain_batched Date: Tue, 24 Sep 2024 08:08:05 +0200 Subject: [PATCH 09/44] wip: improve parallelism --- prover/src/logup_gkr/mod.rs | 8 +- prover/src/logup_gkr/prover.rs | 19 ---- sumcheck/src/prover/plain.rs | 187 +++++++++++++++++++++++---------- 3 files changed, 134 insertions(+), 80 deletions(-) diff --git a/prover/src/logup_gkr/mod.rs b/prover/src/logup_gkr/mod.rs index e9736aa01..e43b883d2 100644 --- a/prover/src/logup_gkr/mod.rs +++ b/prover/src/logup_gkr/mod.rs @@ -1,14 +1,10 @@ use alloc::vec::Vec; -use core::{ - fmt::{self, Formatter}, - ops::Add, -}; use air::{EvaluationFrame, GkrData, LogUpGkrEvaluator}; use math::FieldElement; -use sumcheck::{CircuitLayer, CircuitLayerPolys, CircuitWire, EqFunction, MultiLinearPoly, SumCheckProverError}; +use sumcheck::{CircuitLayer, CircuitLayerPolys, CircuitWire, EqFunction, SumCheckProverError}; use tracing::instrument; -use utils::{chunks, ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; +use utils::chunks; use crate::Trace; diff --git a/prover/src/logup_gkr/prover.rs b/prover/src/logup_gkr/prover.rs index ab39f57b4..b71e5bab0 100644 --- a/prover/src/logup_gkr/prover.rs +++ b/prover/src/logup_gkr/prover.rs @@ -311,29 +311,10 @@ fn sum_check_prove_num_rounds_degree_3< batched_claims.push(claim) } - //let mut numerators_accross_circuits_0 = Vec::with_capacity(inner_layers.len()); - //let mut denominators_accross_circuits_0 = Vec::with_capacity(inner_layers.len()); - //let mut numerators_accross_circuits_1 = Vec::with_capacity(inner_layers.len()); - //let mut denominators_accross_circuits_1 = Vec::with_capacity(inner_layers.len()); - - //for inner_layer in inner_layers { - //let CircuitLayerPolys { numerators, denominators } = inner_layer; - //let (p0, p1) = numerators.project_least_significant_variable(); - //let (q0, q1) = denominators.project_least_significant_variable(); - //numerators_accross_circuits_0.push(p0); - //numerators_accross_circuits_1.push(p1); - //denominators_accross_circuits_0.push(q0); - //denominators_accross_circuits_1.push(q1) - //} - let proof = sumcheck_prove_plain_batched( &batched_claims, r_batch, inner_layers, - //numerators_accross_circuits_0, - //numerators_accross_circuits_1, - //denominators_accross_circuits_0, - //denominators_accross_circuits_1, eq, tensored_batching_randomness, transcript, diff --git a/sumcheck/src/prover/plain.rs b/sumcheck/src/prover/plain.rs index 9cec143a4..2d0442d45 100644 --- a/sumcheck/src/prover/plain.rs +++ b/sumcheck/src/prover/plain.rs @@ -11,7 +11,7 @@ use math::FieldElement; pub use rayon::prelude::*; use smallvec::smallvec; -use super::{CircuitLayer, CircuitLayerPolys, SumCheckProverError}; +use super::{CircuitLayerPolys, SumCheckProverError}; use crate::{ comb_func, CompressedUnivariatePolyEvals, FinalOpeningClaim, MultiLinearPoly, RoundProof, SumCheckProof, @@ -225,10 +225,6 @@ pub fn sumcheck_prove_plain_batched>, - //mut p0_s: Vec>, - //mut p1_s: Vec>, - //mut q0_s: Vec>, - //mut q1_s: Vec>, eq: &mut MultiLinearPoly, tensored_batching_randomness: &[E], transcript: &mut impl RandomCoin, @@ -248,13 +244,94 @@ pub fn sumcheck_prove_plain_batched Date: Tue, 24 Sep 2024 10:09:04 +0200 Subject: [PATCH 10/44] cleanup: parallel for degree 3 sum-check --- prover/src/logup_gkr/mod.rs | 2 +- prover/src/logup_gkr/prover.rs | 39 ++-- sumcheck/src/prover/high_degree.rs | 1 - sumcheck/src/prover/mod.rs | 2 +- sumcheck/src/prover/plain.rs | 283 ++++++++++++++++++----------- sumcheck/src/verifier/mod.rs | 29 +-- 6 files changed, 217 insertions(+), 139 deletions(-) diff --git a/prover/src/logup_gkr/mod.rs b/prover/src/logup_gkr/mod.rs index e43b883d2..fc53bc529 100644 --- a/prover/src/logup_gkr/mod.rs +++ b/prover/src/logup_gkr/mod.rs @@ -4,7 +4,7 @@ use air::{EvaluationFrame, GkrData, LogUpGkrEvaluator}; use math::FieldElement; use sumcheck::{CircuitLayer, CircuitLayerPolys, CircuitWire, EqFunction, SumCheckProverError}; use tracing::instrument; -use utils::chunks; +use utils::{chunks, uninit_vector}; use crate::Trace; diff --git a/prover/src/logup_gkr/prover.rs b/prover/src/logup_gkr/prover.rs index b71e5bab0..86d670491 100644 --- a/prover/src/logup_gkr/prover.rs +++ b/prover/src/logup_gkr/prover.rs @@ -4,8 +4,9 @@ use air::{LogUpGkrEvaluator, LogUpGkrOracle, PeriodicTable}; use crypto::{ElementHasher, RandomCoin}; use math::FieldElement; use sumcheck::{ - sum_check_prove_higher_degree, sumcheck_prove_plain_batched, BeforeFinalLayerProof, - CircuitOutput, EqFunction, FinalLayerProof, GkrCircuitProof, MultiLinearPoly, SumCheckProof, + sum_check_prove_higher_degree, sumcheck_prove_plain_batched, + sumcheck_prove_plain_batched_serial, BeforeFinalLayerProof, CircuitOutput, EqFunction, + FinalLayerProof, GkrCircuitProof, MultiLinearPoly, SumCheckProof, }; use tracing::instrument; @@ -132,9 +133,13 @@ fn prove_input_layer< evaluation_point, claimed_evaluations_per_circuit: claimed_evaluations, } = claim; + + let mut all_claims_concatenated = Vec::with_capacity(claimed_evaluations.len()); for claimed_evaluation in claimed_evaluations.iter() { - transcript.reseed(H::hash_elements(&[claimed_evaluation.0, claimed_evaluation.1])); + all_claims_concatenated.extend_from_slice(&[claimed_evaluation.0, claimed_evaluation.1]); } + transcript.reseed(H::hash_elements(&all_claims_concatenated)); + let r_batch = transcript.draw().map_err(|_| GkrProverError::FailedToGenerateChallenge)?; let mut full_claim = E::ZERO; for (circuit_idx, claimed_evaluation) in claimed_evaluations.iter().enumerate() { @@ -310,15 +315,25 @@ fn sum_check_prove_num_rounds_degree_3< let claim = claim.0 + claim.1 * r_batch; batched_claims.push(claim) } - - let proof = sumcheck_prove_plain_batched( - &batched_claims, - r_batch, - inner_layers, - eq, - tensored_batching_randomness, - transcript, - )?; + let proof = if inner_layers[0].numerators.num_evaluations() >= 512 { + sumcheck_prove_plain_batched( + &batched_claims, + r_batch, + inner_layers, + eq, + tensored_batching_randomness, + transcript, + )? + } else { + sumcheck_prove_plain_batched_serial( + &batched_claims, + r_batch, + inner_layers, + eq, + tensored_batching_randomness, + transcript, + )? + }; Ok(proof) } diff --git a/sumcheck/src/prover/high_degree.rs b/sumcheck/src/prover/high_degree.rs index 25fe28ddf..af554ed7a 100644 --- a/sumcheck/src/prover/high_degree.rs +++ b/sumcheck/src/prover/high_degree.rs @@ -187,7 +187,6 @@ pub fn sum_check_prove_higher_degree< // reseed with the s_0 polynomial coin.reseed(H::hash_elements(&round_poly_coefs.0)); round_proofs.push(RoundProof { round_poly_coefs }); - //libc_println!("current_round_claim {:?}", current_round_claim); for i in 1..num_rounds { // generate random challenge r_i for the i-th round let round_challenge = diff --git a/sumcheck/src/prover/mod.rs b/sumcheck/src/prover/mod.rs index 8267acdb9..354a84b22 100644 --- a/sumcheck/src/prover/mod.rs +++ b/sumcheck/src/prover/mod.rs @@ -11,7 +11,7 @@ pub use high_degree::sum_check_prove_higher_degree; mod plain; use math::FieldElement; -//pub use plain::sumcheck_prove_plain; +pub use plain::sumcheck_prove_plain_batched_serial; pub use plain::sumcheck_prove_plain_batched; mod error; diff --git a/sumcheck/src/prover/plain.rs b/sumcheck/src/prover/plain.rs index 2d0442d45..b3fc448c2 100644 --- a/sumcheck/src/prover/plain.rs +++ b/sumcheck/src/prover/plain.rs @@ -233,30 +233,28 @@ pub fn sumcheck_prove_plain_batched, +>( + claims: &[E], + r_batch: E, + mut inner_layers: Vec>, + eq: &mut MultiLinearPoly, + tensored_batching_randomness: &[E], + transcript: &mut impl RandomCoin, +) -> Result, SumCheckProverError> { + let mut round_proofs = vec![]; + + let mut challenges = vec![]; + + let mut batched_claim_across_circuits = claims + .iter() + .zip(tensored_batching_randomness.iter()) + .fold(E::ZERO, |acc, (&claim_for_circuit_i, &randomness_for_circuit_i)| { + acc + claim_for_circuit_i * randomness_for_circuit_i + }); + + let num_sum_check_rounds = inner_layers[0].numerators.num_variables() - 1; + for _ in 0..num_sum_check_rounds { + let mut all_round_poly_eval_at_1 = E::ZERO; + let mut all_round_poly_eval_at_2 = E::ZERO; + let mut all_round_poly_eval_at_3 = E::ZERO; + let len = inner_layers[0].numerators.num_evaluations() / 4; + for (inner_layer, batching_randomness) in inner_layers.iter().zip(tensored_batching_randomness) { - let (round_poly_eval_at_1, round_poly_eval_at_2, round_poly_eval_at_3) = (0..len).fold( (E::ZERO, E::ZERO, E::ZERO), |(acc_point_1, acc_point_2, acc_point_3), i| { @@ -387,70 +524,7 @@ pub fn sumcheck_prove_plain_batched, ) -> Result, SumCheckVerifierError> { + let mut all_claims_concatenated = Vec::with_capacity(claim.len()); for claimed_evaluation in claim.iter() { - transcript.reseed(H::hash_elements(&[claimed_evaluation.0, claimed_evaluation.1])); + all_claims_concatenated.extend_from_slice(&[claimed_evaluation.0, claimed_evaluation.1]); } + + transcript.reseed(H::hash_elements(&all_claims_concatenated)); // generate challenge to batch sum-checks let r_batch: E = transcript @@ -108,7 +107,6 @@ pub fn verify_sum_check_input_layer Date: Tue, 24 Sep 2024 10:49:41 +0200 Subject: [PATCH 11/44] fix: add with capacity --- prover/src/logup_gkr/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/prover/src/logup_gkr/mod.rs b/prover/src/logup_gkr/mod.rs index fc53bc529..91d8c4c24 100644 --- a/prover/src/logup_gkr/mod.rs +++ b/prover/src/logup_gkr/mod.rs @@ -119,7 +119,7 @@ impl EvaluatedCircuit { let periodic_values = evaluator.build_periodic_values(trace.main_segment().num_rows()); let mut input_layer_wires: Vec> = - vec![Vec::with_capacity(main_trace.main_segment().num_rows()); num_fractions]; + vec![unsafe{uninit_vector(main_trace.main_segment().num_rows())}; num_fractions]; let mut main_frame = EvaluationFrame::new(main_trace.main_segment().num_cols()); let mut query = vec![E::BaseField::ZERO; evaluator.get_oracles().len()]; @@ -143,7 +143,7 @@ impl EvaluatedCircuit { .zip(denominators.iter()) .zip(input_layer_wires.iter_mut()) .for_each(|((numerator, denominator), circuit_input_layer)| { - circuit_input_layer.push(CircuitWire::new(*numerator, *denominator)) + circuit_input_layer[i] = CircuitWire::new(*numerator, *denominator) }); } From 9fd2f233e9e1b40a028ebce3eff69f454c523503 Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Tue, 24 Sep 2024 12:40:06 +0200 Subject: [PATCH 12/44] fix nameing --- prover/src/logup_gkr/mod.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/prover/src/logup_gkr/mod.rs b/prover/src/logup_gkr/mod.rs index 91d8c4c24..38df7aac3 100644 --- a/prover/src/logup_gkr/mod.rs +++ b/prover/src/logup_gkr/mod.rs @@ -111,7 +111,7 @@ impl EvaluatedCircuit { /// Generates the input layer of the circuit from the main trace columns and some randomness /// provided by the verifier. fn generate_input_layer( - main_trace: &impl Trace, + trace: &impl Trace, evaluator: &impl LogUpGkrEvaluator, log_up_randomness: &[E], ) -> Vec> { @@ -119,15 +119,15 @@ impl EvaluatedCircuit { let periodic_values = evaluator.build_periodic_values(trace.main_segment().num_rows()); let mut input_layer_wires: Vec> = - vec![unsafe{uninit_vector(main_trace.main_segment().num_rows())}; num_fractions]; - let mut main_frame = EvaluationFrame::new(main_trace.main_segment().num_cols()); + vec![unsafe{uninit_vector(trace.main_segment().num_rows())}; num_fractions]; + let mut main_frame = EvaluationFrame::new(trace.main_segment().num_cols()); let mut query = vec![E::BaseField::ZERO; evaluator.get_oracles().len()]; let mut periodic_values_row = vec![E::BaseField::ZERO; periodic_values.num_columns()]; let mut numerators = vec![E::ZERO; num_fractions]; let mut denominators = vec![E::ZERO; num_fractions]; - for i in 0..main_trace.main_segment().num_rows() { - main_trace.read_main_frame(i, &mut main_frame); + for i in 0..trace.main_segment().num_rows() { + trace.read_main_frame(i, &mut main_frame); periodic_values.fill_periodic_values_at(i, &mut periodic_values_row); evaluator.build_query(&main_frame, &mut query); From b8899c05e3ee3c22942cf10681a2a28ea97a4635 Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Wed, 25 Sep 2024 14:21:23 +0200 Subject: [PATCH 13/44] after rebase --- sumcheck/src/prover/plain.rs | 5 ++++- sumcheck/src/verifier/mod.rs | 9 ++------- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/sumcheck/src/prover/plain.rs b/sumcheck/src/prover/plain.rs index b3fc448c2..49ea5f9f5 100644 --- a/sumcheck/src/prover/plain.rs +++ b/sumcheck/src/prover/plain.rs @@ -49,6 +49,9 @@ use crate::{ /// Note that the degree of the non-linear composition polynomial is 3. /// /// [1]: https://eprint.iacr.org/2023/1284 + + +/* #[allow(clippy::too_many_arguments)] pub fn sumcheck_prove_plain>( mut claim: E, @@ -217,7 +220,7 @@ pub fn sumcheck_prove_plain Date: Wed, 25 Sep 2024 16:35:47 +0200 Subject: [PATCH 14/44] added targeted bench for LogUp-GKR --- prover/Cargo.toml | 4 + prover/benches/logup_gkr.rs | 153 +++++------------------------ prover/src/lib.rs | 2 +- prover/src/logup_gkr/mod.rs | 4 +- sumcheck/src/prover/high_degree.rs | 13 ++- sumcheck/src/verifier/mod.rs | 30 ++---- 6 files changed, 44 insertions(+), 162 deletions(-) diff --git a/prover/Cargo.toml b/prover/Cargo.toml index e1a47363e..86aa1f334 100644 --- a/prover/Cargo.toml +++ b/prover/Cargo.toml @@ -19,6 +19,10 @@ bench = false name = "logup_gkr" harness = false +[[bench]] +name = "logup_gkr_e2e" +harness = false + [[bench]] name = "row_matrix" harness = false diff --git a/prover/benches/logup_gkr.rs b/prover/benches/logup_gkr.rs index e484955bf..e86c84aef 100644 --- a/prover/benches/logup_gkr.rs +++ b/prover/benches/logup_gkr.rs @@ -6,24 +6,22 @@ use std::{marker::PhantomData, time::Duration, vec::Vec}; use air::{ - Air, AirContext, Assertion, AuxRandElements, ConstraintCompositionCoefficients, - EvaluationFrame, FieldExtension, LogUpGkrEvaluator, LogUpGkrOracle, ProofOptions, TraceInfo, - TransitionConstraintDegree, + Air, AirContext, Assertion, AuxRandElements, EvaluationFrame, LogUpGkrEvaluator, + LogUpGkrOracle, ProofOptions, TraceInfo, TransitionConstraintDegree, }; use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion}; -use crypto::MerkleTree; +use crypto::RandomCoin; use math::StarkField; use winter_prover::{ crypto::{hashers::Blake3_256, DefaultRandomCoin}, math::{fields::f64::BaseElement, ExtensionOf, FieldElement}, matrix::ColMatrix, - DefaultTraceLde, LogUpGkrConstraintEvaluator, Prover, StarkDomain, Trace, TracePolyTable, + prove_gkr, Trace, }; -const TRACE_LENS: [usize; 2] = [2_usize.pow(18), 2_usize.pow(20)]; -const AUX_TRACE_WIDTH: usize = 2; +const TRACE_LENS: [usize; 4] = [2_usize.pow(18), 2_usize.pow(19), 2_usize.pow(20), 2_usize.pow(21)]; -/// Simple end-to-end benchmark for LogUp-GKR. +/// Simple benchmark for the GKR part of STARK with LogUp-GKR. /// /// The main trace contains `5` columns and the LogUp relation is a simple one where we have: /// @@ -33,27 +31,31 @@ const AUX_TRACE_WIDTH: usize = 2; /// /// Given the above, the benchmark then gives an idea about the minimal overhead due to enabling /// LogUp-GKR. The overhead could be bigger depending on the complexity of the LogUp relation. -fn prove_with_lagrange_kernel(c: &mut Criterion) { - let mut group = c.benchmark_group("prove with Lagrange kernel column"); +fn prove_with_logup_gkr(c: &mut Criterion) { + let mut group = c.benchmark_group("prove LogUp-GKR"); group.sample_size(10); group.measurement_time(Duration::from_secs(20)); for &trace_len in TRACE_LENS.iter() { group.bench_function(BenchmarkId::new("", trace_len), |b| { - let trace = LogUpGkrSimpleTrace::new(trace_len, AUX_TRACE_WIDTH); - let prover = LogUpGkrSimpleProver::new(AUX_TRACE_WIDTH); + let main_trace = LogUpGkrSimpleTrace::new(trace_len); + let evaluator = PlainLogUpGkrEval::new(); b.iter_batched( - || trace.clone(), - |trace| prover.prove(trace).unwrap(), + || (main_trace.clone(), evaluator.clone()), + |(main_trace, evaluator)| { + let mut public_coin = + DefaultRandomCoin::>::new(&[BaseElement::ZERO; 4]); + prove_gkr::(&main_trace, &evaluator, &mut public_coin) + }, BatchSize::SmallInput, ) }); } } -criterion_group!(lagrange_kernel_group, prove_with_lagrange_kernel); -criterion_main!(lagrange_kernel_group); +criterion_group!(logup_gkr_group, prove_with_logup_gkr); +criterion_main!(logup_gkr_group); // LogUpGkrSimple // ================================================================================================= @@ -66,7 +68,7 @@ struct LogUpGkrSimpleTrace { } impl LogUpGkrSimpleTrace { - fn new(trace_len: usize, aux_segment_width: usize) -> Self { + fn new(trace_len: usize) -> Self { assert!(trace_len < u32::MAX.try_into().unwrap()); // we create a column for the table we are looking values into. These are just the integers @@ -103,7 +105,7 @@ impl LogUpGkrSimpleTrace { Self { main_trace: ColMatrix::new(vec![table, multiplicity, values_0, values_1, values_2]), - info: TraceInfo::new_multi_segment(5, aux_segment_width, 0, trace_len, vec![], true), + info: TraceInfo::new_multi_segment(5, 0, 0, trace_len, vec![], true), } } @@ -237,11 +239,11 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { } fn get_num_fractions(&self) -> usize { - 16 + 4 } fn max_degree(&self) -> usize { - 10 + 3 } fn build_query(&self, frame: &EvaluationFrame, query: &mut [E]) @@ -262,42 +264,18 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { F: FieldElement, E: FieldElement + ExtensionOf, { - assert_eq!(numerator.len(), 16); - assert_eq!(denominator.len(), 16); + assert_eq!(numerator.len(), 4); + assert_eq!(denominator.len(), 4); assert_eq!(query.len(), 5); numerator[0] = E::from(query[1]); numerator[1] = E::ONE; numerator[2] = E::ONE; numerator[3] = E::ONE; - numerator[4] = E::from(query[1]); - numerator[5] = E::ONE; - numerator[6] = E::ONE; - numerator[7] = E::ONE; - numerator[8] = E::from(query[1]); - numerator[9] = E::ONE; - numerator[10] = E::ONE; - numerator[11] = E::ONE; - numerator[12] = E::from(query[1]); - numerator[13] = E::ONE; - numerator[14] = E::ONE; - numerator[15] = E::ONE; denominator[0] = rand_values[0] - E::from(query[0]); denominator[1] = -(rand_values[0] - E::from(query[2])); denominator[2] = -(rand_values[0] - E::from(query[3])); denominator[3] = -(rand_values[0] - E::from(query[4])); - denominator[4] = rand_values[0] - E::from(query[0]); - denominator[5] = -(rand_values[0] - E::from(query[2])); - denominator[6] = -(rand_values[0] - E::from(query[3])); - denominator[7] = -(rand_values[0] - E::from(query[4])); - denominator[8] = rand_values[0] - E::from(query[0]); - denominator[9] = -(rand_values[0] - E::from(query[2])); - denominator[10] = -(rand_values[0] - E::from(query[3])); - denominator[11] = -(rand_values[0] - E::from(query[4])); - denominator[12] = rand_values[0] - E::from(query[0]); - denominator[13] = -(rand_values[0] - E::from(query[2])); - denominator[14] = -(rand_values[0] - E::from(query[3])); - denominator[15] = -(rand_values[0] - E::from(query[4])); } fn compute_claim(&self, _inputs: &Self::PublicInputs, _rand_values: &[E]) -> E @@ -307,86 +285,3 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { E::ZERO } } -// Prover -// ================================================================================================ - -struct LogUpGkrSimpleProver { - aux_trace_width: usize, - options: ProofOptions, -} - -impl LogUpGkrSimpleProver { - fn new(aux_trace_width: usize) -> Self { - Self { - aux_trace_width, - options: ProofOptions::new(1, 8, 0, FieldExtension::Quadratic, 2, 1), - } - } -} - -impl Prover for LogUpGkrSimpleProver { - type BaseField = BaseElement; - type Air = LogUpGkrSimpleAir; - type Trace = LogUpGkrSimpleTrace; - type HashFn = Blake3_256; - type VC = MerkleTree>; - type RandomCoin = DefaultRandomCoin; - type TraceLde> = - DefaultTraceLde; - type ConstraintEvaluator<'a, E: FieldElement> = - LogUpGkrConstraintEvaluator<'a, LogUpGkrSimpleAir, E>; - - fn get_pub_inputs(&self, _trace: &Self::Trace) -> <::Air as Air>::PublicInputs { - } - - fn options(&self) -> &ProofOptions { - &self.options - } - - fn new_trace_lde( - &self, - trace_info: &TraceInfo, - main_trace: &ColMatrix, - domain: &StarkDomain, - ) -> (Self::TraceLde, TracePolyTable) - where - E: math::FieldElement, - { - DefaultTraceLde::new(trace_info, main_trace, domain) - } - - fn new_evaluator<'a, E>( - &self, - air: &'a Self::Air, - aux_rand_elements: Option>, - composition_coefficients: ConstraintCompositionCoefficients, - ) -> Self::ConstraintEvaluator<'a, E> - where - E: math::FieldElement, - { - LogUpGkrConstraintEvaluator::new(air, aux_rand_elements.unwrap(), composition_coefficients) - } - - fn build_aux_trace(&self, main_trace: &Self::Trace, _aux_rand_elements: &[E]) -> ColMatrix - where - E: FieldElement, - { - let main_trace = main_trace.main_segment(); - - let mut columns = Vec::new(); - - let rand_summed = E::from(777_u32); - for _ in 0..self.aux_trace_width { - // building a dummy auxiliary column - let column = main_trace - .get_column(0) - .iter() - .map(|row_val| rand_summed.mul_base(*row_val)) - .collect(); - - columns.push(column); - } - - ColMatrix::new(columns) - } -} diff --git a/prover/src/lib.rs b/prover/src/lib.rs index 9bcc566b6..d9da99dd5 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -88,7 +88,7 @@ pub use trace::{ }; mod logup_gkr; -use logup_gkr::{build_lagrange_column, build_s_column, prove_gkr}; +pub use logup_gkr::{build_lagrange_column, build_s_column, prove_gkr}; mod channel; use channel::ProverChannel; diff --git a/prover/src/logup_gkr/mod.rs b/prover/src/logup_gkr/mod.rs index 38df7aac3..4c6dd80d1 100644 --- a/prover/src/logup_gkr/mod.rs +++ b/prover/src/logup_gkr/mod.rs @@ -11,7 +11,7 @@ use crate::Trace; mod prover; pub use prover::prove_gkr; #[cfg(feature = "concurrent")] -pub use utils::rayon::{current_num_threads as rayon_num_threads, prelude::*}; +pub use utils::rayon::prelude::*; // EVALUATED CIRCUIT // ================================================================================================ @@ -191,7 +191,7 @@ pub struct GkrClaim { /// product i.e., equation (12) in [1]. This oracle is refered to throughout the codebase as /// the s-column. /// -/// The following function's purpose is two build the column in point 2 given the one in point 1. +/// The following function's purpose is to build the column in point 2 given the one in point 1. /// /// [1]: https://eprint.iacr.org/2023/1284 pub fn build_s_column( diff --git a/sumcheck/src/prover/high_degree.rs b/sumcheck/src/prover/high_degree.rs index af554ed7a..906692b9d 100644 --- a/sumcheck/src/prover/high_degree.rs +++ b/sumcheck/src/prover/high_degree.rs @@ -395,9 +395,9 @@ fn sumcheck_round( } for i in 0..num_periodic { deltas_periodic_zero[i] = - evals_periodic_zero_one[i] - evals_periodic_zero_zero[i]; - evals_periodic_x_zero[i] = evals_periodic_zero_one[i]; - deltas_periodic_one[i] = evals_periodic_one_one[i] - evals_periodic_one_zero[i]; + evals_periodic_one_zero[i] - evals_periodic_zero_zero[i]; + evals_periodic_x_zero[i] = evals_periodic_one_zero[i]; + deltas_periodic_one[i] = evals_periodic_one_one[i] - evals_periodic_zero_one[i]; evals_periodic_x_one[i] = evals_periodic_one_one[i]; } eq_delta = eq_at_one - eq_at_zero; @@ -566,9 +566,9 @@ fn sumcheck_round( } for i in 0..num_periodic { deltas_periodic_zero[i] = - evals_periodic_zero_one[i] - evals_periodic_zero_zero[i]; - evals_periodic_x_zero[i] = evals_periodic_zero_one[i]; - deltas_periodic_one[i] = evals_periodic_one_one[i] - evals_periodic_one_zero[i]; + evals_periodic_one_zero[i] - evals_periodic_zero_zero[i]; + evals_periodic_x_zero[i] = evals_periodic_one_zero[i]; + deltas_periodic_one[i] = evals_periodic_one_one[i] - evals_periodic_zero_one[i]; evals_periodic_x_one[i] = evals_periodic_one_one[i]; } let eq_delta = eq_at_one - eq_at_zero; @@ -631,7 +631,6 @@ fn sumcheck_round( evals_periodic_one_one, evals_periodic_x_zero, evals_periodic_x_one, - deltas_zero, deltas_one, deltas_periodic_zero, diff --git a/sumcheck/src/verifier/mod.rs b/sumcheck/src/verifier/mod.rs index 0aec63e35..0ac5a1ed1 100644 --- a/sumcheck/src/verifier/mod.rs +++ b/sumcheck/src/verifier/mod.rs @@ -89,10 +89,9 @@ pub fn verify_sum_check_input_layer( - periodic_columns: PeriodicTable, - eval_point: &[E], -) -> Vec { - let mut evaluations = Vec::with_capacity(periodic_columns.num_columns()); - for col in periodic_columns.table() { - let ml = MultiLinearPoly::from_evaluations(col.to_vec()); - let num_variables = ml.num_variables(); - - let evaluation = ml.evaluate(&eval_point[&eval_point.len() - num_variables..]); - evaluations.push(evaluation) - } - evaluations -} fn evaluate_periodic_columns_at( periodic_columns: PeriodicTable, eval_point: &[E], ) -> (Vec, Vec) { - let mut eval_point_zero = vec![E::ZERO]; - let mut eval_point_one = vec![E::ONE]; - eval_point_zero.extend_from_slice(&eval_point); - eval_point_one.extend_from_slice(&eval_point); + let mut eval_point_zero = eval_point.to_vec(); + let mut eval_point_one = eval_point.to_vec(); + eval_point_zero.push(E::ZERO); + eval_point_one.push(E::ONE); let mut evaluations_zero = vec![]; let mut evaluations_one = vec![]; for col in periodic_columns.table() { let ml = MultiLinearPoly::from_evaluations(col.to_vec()); let num_variables = ml.num_variables(); - let point_zero = &eval_point_zero[..num_variables]; - let point_one = &eval_point_one[..num_variables]; + let point_zero = &eval_point_zero[&eval_point_zero.len() - num_variables..]; + let point_one = &eval_point_one[&eval_point_one.len() - num_variables..]; let evaluation_zero = ml.evaluate(point_zero); evaluations_zero.push(evaluation_zero); From 2a51690b1043d8657cf5d4884215206dbba5faf5 Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Wed, 25 Sep 2024 16:42:56 +0200 Subject: [PATCH 15/44] added targeted bench for LogUp-GKR --- prover/benches/logup_gkr.rs | 32 ++- prover/benches/logup_gkr_e2e.rs | 392 ++++++++++++++++++++++++++++++++ 2 files changed, 420 insertions(+), 4 deletions(-) create mode 100644 prover/benches/logup_gkr_e2e.rs diff --git a/prover/benches/logup_gkr.rs b/prover/benches/logup_gkr.rs index e86c84aef..f5673c8aa 100644 --- a/prover/benches/logup_gkr.rs +++ b/prover/benches/logup_gkr.rs @@ -239,11 +239,11 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { } fn get_num_fractions(&self) -> usize { - 4 + 16 } fn max_degree(&self) -> usize { - 3 + 10 } fn build_query(&self, frame: &EvaluationFrame, query: &mut [E]) @@ -264,18 +264,42 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { F: FieldElement, E: FieldElement + ExtensionOf, { - assert_eq!(numerator.len(), 4); - assert_eq!(denominator.len(), 4); + assert_eq!(numerator.len(), 16); + assert_eq!(denominator.len(), 16); assert_eq!(query.len(), 5); numerator[0] = E::from(query[1]); numerator[1] = E::ONE; numerator[2] = E::ONE; numerator[3] = E::ONE; + numerator[4] = E::from(query[1]); + numerator[5] = E::ONE; + numerator[6] = E::ONE; + numerator[7] = E::ONE; + numerator[8] = E::from(query[1]); + numerator[9] = E::ONE; + numerator[10] = E::ONE; + numerator[11] = E::ONE; + numerator[12] = E::from(query[1]); + numerator[13] = E::ONE; + numerator[14] = E::ONE; + numerator[15] = E::ONE; denominator[0] = rand_values[0] - E::from(query[0]); denominator[1] = -(rand_values[0] - E::from(query[2])); denominator[2] = -(rand_values[0] - E::from(query[3])); denominator[3] = -(rand_values[0] - E::from(query[4])); + denominator[4] = rand_values[0] - E::from(query[0]); + denominator[5] = -(rand_values[0] - E::from(query[2])); + denominator[6] = -(rand_values[0] - E::from(query[3])); + denominator[7] = -(rand_values[0] - E::from(query[4])); + denominator[8] = rand_values[0] - E::from(query[0]); + denominator[9] = -(rand_values[0] - E::from(query[2])); + denominator[10] = -(rand_values[0] - E::from(query[3])); + denominator[11] = -(rand_values[0] - E::from(query[4])); + denominator[12] = rand_values[0] - E::from(query[0]); + denominator[13] = -(rand_values[0] - E::from(query[2])); + denominator[14] = -(rand_values[0] - E::from(query[3])); + denominator[15] = -(rand_values[0] - E::from(query[4])); } fn compute_claim(&self, _inputs: &Self::PublicInputs, _rand_values: &[E]) -> E diff --git a/prover/benches/logup_gkr_e2e.rs b/prover/benches/logup_gkr_e2e.rs new file mode 100644 index 000000000..c15630085 --- /dev/null +++ b/prover/benches/logup_gkr_e2e.rs @@ -0,0 +1,392 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use std::{marker::PhantomData, time::Duration, vec::Vec}; + +use air::{ + Air, AirContext, Assertion, AuxRandElements, ConstraintCompositionCoefficients, + EvaluationFrame, FieldExtension, LogUpGkrEvaluator, LogUpGkrOracle, ProofOptions, TraceInfo, + TransitionConstraintDegree, +}; +use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion}; +use crypto::MerkleTree; +use math::StarkField; +use winter_prover::{ + crypto::{hashers::Blake3_256, DefaultRandomCoin}, + math::{fields::f64::BaseElement, ExtensionOf, FieldElement}, + matrix::ColMatrix, + DefaultTraceLde, LogUpGkrConstraintEvaluator, Prover, StarkDomain, Trace, TracePolyTable, +}; + +const TRACE_LENS: [usize; 2] = [2_usize.pow(18), 2_usize.pow(20)]; +const AUX_TRACE_WIDTH: usize = 2; + +/// Simple end-to-end benchmark for LogUp-GKR. +/// +/// The main trace contains `5` columns and the LogUp relation is a simple one where we have: +/// +/// 1. a table of values from `0` to `trace_len - 1`. +/// 2. a multiplicity column containing the number of look ups for each value in the table. +/// 3. three columns with values contained in the table above. +/// +/// Given the above, the benchmark then gives an idea about the minimal overhead due to enabling +/// LogUp-GKR. The overhead could be bigger depending on the complexity of the LogUp relation. +fn prove_with_logup_gkr(c: &mut Criterion) { + let mut group = c.benchmark_group("prove with LogUp-GKR"); + group.sample_size(10); + group.measurement_time(Duration::from_secs(20)); + + for &trace_len in TRACE_LENS.iter() { + group.bench_function(BenchmarkId::new("", trace_len), |b| { + let trace = LogUpGkrSimpleTrace::new(trace_len, AUX_TRACE_WIDTH); + let prover = LogUpGkrSimpleProver::new(AUX_TRACE_WIDTH); + + b.iter_batched( + || trace.clone(), + |trace| prover.prove(trace).unwrap(), + BatchSize::SmallInput, + ) + }); + } +} + +criterion_group!(lagrange_kernel_group, prove_with_logup_gkr); +criterion_main!(lagrange_kernel_group); + +// LogUpGkrSimple +// ================================================================================================= + +#[derive(Clone, Debug)] +struct LogUpGkrSimpleTrace { + // dummy main trace + main_trace: ColMatrix, + info: TraceInfo, +} + +impl LogUpGkrSimpleTrace { + fn new(trace_len: usize, aux_segment_width: usize) -> Self { + assert!(trace_len < u32::MAX.try_into().unwrap()); + + // we create a column for the table we are looking values into. These are just the integers + // from 0 to `trace_len`. + let table: Vec = + (0..trace_len).map(|idx| BaseElement::from(idx as u32)).collect(); + + // we create three columns that contains values contained in `table`. For simplicity, we + // look up only the values `0` or `1`, we look up the value `1` four times and the value `0` + // `trace_len - 4` times. + + let mut values_0: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + for i in 0..4 { + values_0[i + 4] = BaseElement::ONE; + } + + let mut values_1: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + for i in 0..4 { + values_1[i + 4] = BaseElement::ONE; + } + + let mut values_2: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + for i in 0..4 { + values_2[i + 4] = BaseElement::ONE; + } + + // we create the multiplicity column + let mut multiplicity: Vec = + (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + // we look up the value `1` four times in three columns + multiplicity[1] = BaseElement::new(3 * 4); + // we look up the value `0` `trace_len - 4` in three columns + multiplicity[0] = BaseElement::new(3 * trace_len as u64 - 3 * 4); + + Self { + main_trace: ColMatrix::new(vec![table, multiplicity, values_0, values_1, values_2]), + info: TraceInfo::new_multi_segment(5, aux_segment_width, 0, trace_len, vec![], true), + } + } + + fn len(&self) -> usize { + self.main_trace.num_rows() + } +} + +impl Trace for LogUpGkrSimpleTrace { + type BaseField = BaseElement; + + fn info(&self) -> &TraceInfo { + &self.info + } + + fn main_segment(&self) -> &ColMatrix { + &self.main_trace + } + + fn read_main_frame(&self, row_idx: usize, frame: &mut EvaluationFrame) { + let next_row_idx = row_idx + 1; + self.main_trace.read_row_into(row_idx, frame.current_mut()); + self.main_trace.read_row_into(next_row_idx % self.len(), frame.next_mut()); + } +} + +// AIR +// ================================================================================================= + +struct LogUpGkrSimpleAir { + context: AirContext, +} + +impl Air for LogUpGkrSimpleAir { + type BaseField = BaseElement; + type PublicInputs = (); + + fn new(trace_info: TraceInfo, _pub_inputs: Self::PublicInputs, options: ProofOptions) -> Self { + Self { + context: AirContext::new_multi_segment( + trace_info, + _pub_inputs, + vec![TransitionConstraintDegree::new(1)], + vec![], + 1, + 0, + options, + ), + } + } + + fn context(&self) -> &AirContext { + &self.context + } + + fn evaluate_transition>( + &self, + frame: &EvaluationFrame, + _periodic_values: &[E], + result: &mut [E], + ) { + let current = frame.current()[0]; + let next = frame.next()[0]; + + // increments by 1 + result[0] = next - current - E::ONE; + } + + fn get_assertions(&self) -> Vec> { + vec![Assertion::single(0, 0, BaseElement::ZERO)] + } + + fn evaluate_aux_transition( + &self, + _main_frame: &EvaluationFrame, + _aux_frame: &EvaluationFrame, + _periodic_values: &[F], + _aux_rand_elements: &AuxRandElements, + _result: &mut [E], + ) where + F: FieldElement, + E: FieldElement + ExtensionOf, + { + // do nothing + } + + fn get_aux_assertions>( + &self, + _aux_rand_elements: &AuxRandElements, + ) -> Vec> { + vec![] + } + + fn get_logup_gkr_evaluator( + &self, + ) -> impl LogUpGkrEvaluator + { + PlainLogUpGkrEval::new() + } +} + +#[derive(Clone, Default)] +pub struct PlainLogUpGkrEval { + oracles: Vec, + _field: PhantomData, +} + +impl PlainLogUpGkrEval { + pub fn new() -> Self { + let committed_0 = LogUpGkrOracle::CurrentRow(0); + let committed_1 = LogUpGkrOracle::CurrentRow(1); + let committed_2 = LogUpGkrOracle::CurrentRow(2); + let committed_3 = LogUpGkrOracle::CurrentRow(3); + let committed_4 = LogUpGkrOracle::CurrentRow(4); + let oracles = vec![committed_0, committed_1, committed_2, committed_3, committed_4]; + Self { oracles, _field: PhantomData } + } +} + +impl LogUpGkrEvaluator for PlainLogUpGkrEval { + type BaseField = BaseElement; + + type PublicInputs = (); + + fn get_oracles(&self) -> &[LogUpGkrOracle] { + &self.oracles + } + + fn get_num_rand_values(&self) -> usize { + 1 + } + + fn get_num_fractions(&self) -> usize { + 16 + } + + fn max_degree(&self) -> usize { + 10 + } + + fn build_query(&self, frame: &EvaluationFrame, query: &mut [E]) + where + E: FieldElement, + { + query.iter_mut().zip(frame.current().iter()).for_each(|(q, f)| *q = *f) + } + + fn evaluate_query( + &self, + query: &[F], + _periodic_values: &[F], + rand_values: &[E], + numerator: &mut [E], + denominator: &mut [E], + ) where + F: FieldElement, + E: FieldElement + ExtensionOf, + { + assert_eq!(numerator.len(), 16); + assert_eq!(denominator.len(), 16); + assert_eq!(query.len(), 5); + numerator[0] = E::from(query[1]); + numerator[1] = E::ONE; + numerator[2] = E::ONE; + numerator[3] = E::ONE; + numerator[4] = E::from(query[1]); + numerator[5] = E::ONE; + numerator[6] = E::ONE; + numerator[7] = E::ONE; + numerator[8] = E::from(query[1]); + numerator[9] = E::ONE; + numerator[10] = E::ONE; + numerator[11] = E::ONE; + numerator[12] = E::from(query[1]); + numerator[13] = E::ONE; + numerator[14] = E::ONE; + numerator[15] = E::ONE; + + denominator[0] = rand_values[0] - E::from(query[0]); + denominator[1] = -(rand_values[0] - E::from(query[2])); + denominator[2] = -(rand_values[0] - E::from(query[3])); + denominator[3] = -(rand_values[0] - E::from(query[4])); + denominator[4] = rand_values[0] - E::from(query[0]); + denominator[5] = -(rand_values[0] - E::from(query[2])); + denominator[6] = -(rand_values[0] - E::from(query[3])); + denominator[7] = -(rand_values[0] - E::from(query[4])); + denominator[8] = rand_values[0] - E::from(query[0]); + denominator[9] = -(rand_values[0] - E::from(query[2])); + denominator[10] = -(rand_values[0] - E::from(query[3])); + denominator[11] = -(rand_values[0] - E::from(query[4])); + denominator[12] = rand_values[0] - E::from(query[0]); + denominator[13] = -(rand_values[0] - E::from(query[2])); + denominator[14] = -(rand_values[0] - E::from(query[3])); + denominator[15] = -(rand_values[0] - E::from(query[4])); + } + + fn compute_claim(&self, _inputs: &Self::PublicInputs, _rand_values: &[E]) -> E + where + E: FieldElement, + { + E::ZERO + } +} +// Prover +// ================================================================================================ + +struct LogUpGkrSimpleProver { + aux_trace_width: usize, + options: ProofOptions, +} + +impl LogUpGkrSimpleProver { + fn new(aux_trace_width: usize) -> Self { + Self { + aux_trace_width, + options: ProofOptions::new(1, 8, 0, FieldExtension::Quadratic, 2, 1), + } + } +} + +impl Prover for LogUpGkrSimpleProver { + type BaseField = BaseElement; + type Air = LogUpGkrSimpleAir; + type Trace = LogUpGkrSimpleTrace; + type HashFn = Blake3_256; + type VC = MerkleTree>; + type RandomCoin = DefaultRandomCoin; + type TraceLde> = + DefaultTraceLde; + type ConstraintEvaluator<'a, E: FieldElement> = + LogUpGkrConstraintEvaluator<'a, LogUpGkrSimpleAir, E>; + + fn get_pub_inputs(&self, _trace: &Self::Trace) -> <::Air as Air>::PublicInputs { + } + + fn options(&self) -> &ProofOptions { + &self.options + } + + fn new_trace_lde( + &self, + trace_info: &TraceInfo, + main_trace: &ColMatrix, + domain: &StarkDomain, + ) -> (Self::TraceLde, TracePolyTable) + where + E: math::FieldElement, + { + DefaultTraceLde::new(trace_info, main_trace, domain) + } + + fn new_evaluator<'a, E>( + &self, + air: &'a Self::Air, + aux_rand_elements: Option>, + composition_coefficients: ConstraintCompositionCoefficients, + ) -> Self::ConstraintEvaluator<'a, E> + where + E: math::FieldElement, + { + LogUpGkrConstraintEvaluator::new(air, aux_rand_elements.unwrap(), composition_coefficients) + } + + fn build_aux_trace(&self, main_trace: &Self::Trace, _aux_rand_elements: &[E]) -> ColMatrix + where + E: FieldElement, + { + let main_trace = main_trace.main_segment(); + + let mut columns = Vec::new(); + + let rand_summed = E::from(777_u32); + for _ in 0..self.aux_trace_width { + // building a dummy auxiliary column + let column = main_trace + .get_column(0) + .iter() + .map(|row_val| rand_summed.mul_base(*row_val)) + .collect(); + + columns.push(column); + } + + ColMatrix::new(columns) + } +} From 78a2bf904a44e82e0eae524b0c78dcb89da2bc42 Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Wed, 25 Sep 2024 17:36:11 +0200 Subject: [PATCH 16/44] lower threshold for parallelization --- prover/src/logup_gkr/prover.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prover/src/logup_gkr/prover.rs b/prover/src/logup_gkr/prover.rs index 86d670491..11c654aae 100644 --- a/prover/src/logup_gkr/prover.rs +++ b/prover/src/logup_gkr/prover.rs @@ -315,7 +315,7 @@ fn sum_check_prove_num_rounds_degree_3< let claim = claim.0 + claim.1 * r_batch; batched_claims.push(claim) } - let proof = if inner_layers[0].numerators.num_evaluations() >= 512 { + let proof = if inner_layers[0].numerators.num_evaluations() >= 64 { sumcheck_prove_plain_batched( &batched_claims, r_batch, From 162ab07e028bb6b0c961d9a769d574bca5b5a891 Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Wed, 25 Sep 2024 19:27:09 +0200 Subject: [PATCH 17/44] feat: add parallel MLEs building --- prover/src/logup_gkr/prover.rs | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/prover/src/logup_gkr/prover.rs b/prover/src/logup_gkr/prover.rs index 11c654aae..f64ae77dd 100644 --- a/prover/src/logup_gkr/prover.rs +++ b/prover/src/logup_gkr/prover.rs @@ -9,6 +9,9 @@ use sumcheck::{ FinalLayerProof, GkrCircuitProof, MultiLinearPoly, SumCheckProof, }; use tracing::instrument; +#[cfg(feature = "concurrent")] +use utils::rayon::prelude::*; +use utils::{iter, iter_mut, uninit_vector}; use super::{CircuitLayerPolys, EvaluatedCircuit, GkrClaim, GkrProverError}; use crate::{matrix::ColMatrix, Trace}; @@ -80,7 +83,7 @@ pub fn prove_gkr( // build the MLEs of the relevant main trace columns let main_trace_mls = - build_mls_from_main_trace_segment(evaluator.get_oracles(), main_trace.main_segment())?; + build_mle_from_main_trace_segment(evaluator.get_oracles(), main_trace.main_segment())?; // build the periodic table representing periodic columns as multi-linear extensions let periodic_table = evaluator.build_periodic_values(main_trace.main_segment().num_rows()); @@ -165,29 +168,34 @@ fn prove_input_layer< /// Builds the multi-linear extension polynomials needed to run the final sum-check of GKR for /// LogUp-GKR. #[instrument(skip_all)] -fn build_mls_from_main_trace_segment( +fn build_mle_from_main_trace_segment( oracles: &[LogUpGkrOracle], main_trace: &ColMatrix<::BaseField>, ) -> Result>, GkrProverError> { - let mut mls = vec![]; + let mut mls = Vec::with_capacity(oracles.len()); for oracle in oracles { match oracle { LogUpGkrOracle::CurrentRow(index) => { let col = main_trace.get_column(*index); - let values: Vec = col.iter().map(|value| E::from(*value)).collect(); + let values: Vec = iter!(col).map(|value| E::from(*value)).collect(); let ml = MultiLinearPoly::from_evaluations(values); mls.push(ml) }, LogUpGkrOracle::NextRow(index) => { let col = main_trace.get_column(*index); - let mut values: Vec = col.iter().map(|value| E::from(*value)).collect(); - values.rotate_left(1); + + let mut values: Vec = unsafe { uninit_vector(col.len()) }; + values[col.len() - 1] = E::from(col[0]); + iter_mut!(&mut values[..col.len() - 1]) + .enumerate() + .for_each(|(i, value)| *value = E::from(col[i + 1])); let ml = MultiLinearPoly::from_evaluations(values); mls.push(ml) }, }; } + Ok(mls) } From 6d869569db38c087b60f4b451d93060c4c0d5dd6 Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Wed, 25 Sep 2024 19:33:16 +0200 Subject: [PATCH 18/44] feat: parallel s-column construction --- prover/src/logup_gkr/mod.rs | 121 +++++++++++++++++++++++++++++++----- utils/core/src/iterators.rs | 18 ++++++ 2 files changed, 122 insertions(+), 17 deletions(-) diff --git a/prover/src/logup_gkr/mod.rs b/prover/src/logup_gkr/mod.rs index 4c6dd80d1..ce36207c6 100644 --- a/prover/src/logup_gkr/mod.rs +++ b/prover/src/logup_gkr/mod.rs @@ -11,7 +11,10 @@ use crate::Trace; mod prover; pub use prover::prove_gkr; #[cfg(feature = "concurrent")] -pub use utils::rayon::prelude::*; +pub use utils::{ + rayon::{current_num_threads as rayon_num_threads, prelude::*}, + {chunks_mut, iter, iter_mut, batch_iter_mut}, +}; // EVALUATED CIRCUIT // ================================================================================================ @@ -119,7 +122,7 @@ impl EvaluatedCircuit { let periodic_values = evaluator.build_periodic_values(trace.main_segment().num_rows()); let mut input_layer_wires: Vec> = - vec![unsafe{uninit_vector(trace.main_segment().num_rows())}; num_fractions]; + vec![unsafe { uninit_vector(trace.main_segment().num_rows()) }; num_fractions]; let mut main_frame = EvaluationFrame::new(trace.main_segment().num_cols()); let mut query = vec![E::BaseField::ZERO; evaluator.get_oracles().len()]; @@ -195,31 +198,65 @@ pub struct GkrClaim { /// /// [1]: https://eprint.iacr.org/2023/1284 pub fn build_s_column( - main_trace: &impl Trace, + trace: &impl Trace, gkr_data: &GkrData, evaluator: &impl LogUpGkrEvaluator, lagrange_kernel_col: &[E], ) -> Vec { let c = gkr_data.compute_batched_claim(); - let main_segment = main_trace.main_segment(); - let mean = c / E::from(E::BaseField::from(main_segment.num_rows() as u32)); + let num_oracles = evaluator.get_oracles().len(); - let mut result = Vec::with_capacity(main_segment.num_rows()); - let mut last_value = E::ZERO; - result.push(last_value); + let main_segment = trace.main_segment(); + let num_cols = main_segment.num_cols(); + let num_rows = main_segment.num_rows(); + let mean = c / E::from(E::BaseField::from(num_rows as u32)); - let mut query = vec![E::BaseField::ZERO; evaluator.get_oracles().len()]; - let mut main_frame = EvaluationFrame::new(main_trace.main_segment().num_cols()); + #[cfg(not(feature = "concurrent"))] + let result = { + let mut result = Vec::with_capacity(num_rows); + let mut last_value = E::ZERO; + result.push(last_value); - for (i, item) in lagrange_kernel_col.iter().enumerate().take(main_segment.num_rows() - 1) { - main_trace.read_main_frame(i, &mut main_frame); + let mut query = vec![E::BaseField::ZERO; num_oracles]; + let mut main_frame = EvaluationFrame::new(num_cols); - evaluator.build_query(&main_frame, &mut query); - let cur_value = last_value - mean + gkr_data.compute_batched_query(&query) * *item; + for (i, item) in lagrange_kernel_col.iter().enumerate().take(num_rows - 1) { + trace.read_main_frame(i, &mut main_frame); - result.push(cur_value); - last_value = cur_value; - } + evaluator.build_query(&main_frame, &mut query); + let cur_value = last_value - mean + gkr_data.compute_batched_query(&query) * *item; + + result.push(cur_value); + last_value = cur_value; + } + + result + }; + + #[cfg(feature = "concurrent")] + let result = { + let mut deltas = unsafe { uninit_vector(num_rows) }; + deltas[0] = E::ZERO; + let batch_size = num_rows / rayon_num_threads().next_power_of_two(); + batch_iter_mut!(&mut deltas[1..], batch_size, |batch: &mut [E], batch_offset: usize| { + let mut query = vec![E::BaseField::ZERO; num_oracles]; + let mut main_frame = EvaluationFrame::::new(num_cols); + + for (i, v) in batch.iter_mut().enumerate() { + trace.read_main_frame(i + batch_offset, &mut main_frame); + + evaluator.build_query(&main_frame, &mut query); + *v = gkr_data.compute_batched_query(&query) * lagrange_kernel_col[i + batch_offset] + - mean; + } + }); + + // note that `deltas[0]` is set `0` and thus `deltas` satisfies the conditions for invoking + // the function + let mut cumulative_sum = deltas; + prefix_sum_parallel(&mut cumulative_sum, batch_size); + cumulative_sum + }; result } @@ -236,3 +273,53 @@ pub enum GkrProverError { #[error("failed to generate the random challenge")] FailedToGenerateChallenge, } + +// HELPER +// ================================================================================================= + +/// Computes the cumulative sum, also called prefix sum, of a vector of field elements using +/// parallelism, in place. +/// +/// The function divides the vector into non-overlapping segments and then computes an array of sums +/// for each segment. The function then applies the naive serial implementation to each segment and +/// uses the pre-computed sums in each segment in order to coordinate the results in the different +/// segments. +/// +/// The input vector is of the form `0 || values` where `values` are the values the cumulative sum +/// vector will be computed for, in place. +#[cfg(feature = "concurrent")] +fn prefix_sum_parallel(vector: &mut [E], batch_size: usize) { + let num_partitions = (vector.len() + batch_size - 1) / batch_size; + let mut sum_per_partition = vec![E::ZERO; num_partitions]; + + chunks!(vector, batch_size) + .zip(iter_mut!(sum_per_partition)) + .for_each(|(chunk, entry)| *entry = chunk.iter().fold(E::ZERO, |acc, term| acc + *term)); + + prefix_sum_truncate_right(&mut sum_per_partition); + + chunks_mut!(vector, batch_size) + .zip(iter!(sum_per_partition)) + .for_each(|(chunk, sum_so_far)| prefix_sum_truncate_left(chunk, *sum_so_far)); +} + +/// Computes the cumulative sum of a vector but omits the final cumulative sum. +#[cfg(feature = "concurrent")] +fn prefix_sum_truncate_right(values: &mut [E]) { + let mut sum = E::ZERO; + values.iter_mut().for_each(|v| { + let tmp = *v; + *v = sum; + sum += tmp; + }); +} + +/// Computes the cumulative sum of a vector but omits the initial cumulative sum, namely zero. +#[cfg(feature = "concurrent")] +fn prefix_sum_truncate_left(values: &mut [E], sum: E) { + let mut sum = sum; + values.iter_mut().for_each(|v| { + sum += *v; + *v = sum; + }); +} diff --git a/utils/core/src/iterators.rs b/utils/core/src/iterators.rs index f978acd40..cf8328ae9 100644 --- a/utils/core/src/iterators.rs +++ b/utils/core/src/iterators.rs @@ -133,3 +133,21 @@ macro_rules! chunks { result }}; } + +/// Returns either a regular or a parallel mutable iterator over at most `chunk_size` elements +/// depending on whether `concurrent` feature is enabled. +/// +/// When `concurrent` feature is enabled, creates a parallel iterator; otherwise, creates a +/// regular iterator. +#[macro_export] +macro_rules! chunks_mut { + ($e: expr, $chunk_size: expr) => {{ + #[cfg(feature = "concurrent")] + let result = $e.par_chunks_mut($chunk_size); + + #[cfg(not(feature = "concurrent"))] + let result = $e.chunks_mut($chunk_size); + + result + }}; +} From a976bfba138b996429d0a725b6d4f87f7f76bf69 Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Wed, 25 Sep 2024 20:13:11 +0200 Subject: [PATCH 19/44] wip: reduce degree in parallel GKR --- prover/src/logup_gkr/prover.rs | 6 +- sumcheck/src/lib.rs | 2 +- sumcheck/src/prover/high_degree.rs | 205 ++++++++--- sumcheck/src/prover/mod.rs | 67 ++++ sumcheck/src/prover/plain.rs | 531 ++++++++++++----------------- sumcheck/src/univariate.rs | 113 +++--- 6 files changed, 489 insertions(+), 435 deletions(-) diff --git a/prover/src/logup_gkr/prover.rs b/prover/src/logup_gkr/prover.rs index f64ae77dd..f6b7c1d35 100644 --- a/prover/src/logup_gkr/prover.rs +++ b/prover/src/logup_gkr/prover.rs @@ -251,12 +251,13 @@ fn prove_intermediate_layers< // reduced in terms of the input layer separately in `prove_final_circuit_layer`. for inner_layer in circuit.layers().into_iter().skip(1).rev().skip(1) { // construct the Lagrange kernel evaluated at the previous GKR round randomness - let mut eq_mle = EqFunction::ml_at(evaluation_point.into()); + let mut eq_mle = EqFunction::ml_at(evaluation_point.clone().into()); // run the sumcheck protocol let proof = sum_check_prove_num_rounds_degree_3( inner_layer, &claimed_evaluations, + &evaluation_point, &mut eq_mle, &tensored_circuit_batching_randomness, transcript, @@ -306,6 +307,7 @@ fn sum_check_prove_num_rounds_degree_3< >( inner_layers: Vec>, claims: &[(E, E)], + evaluation_point: &[E], eq: &mut MultiLinearPoly, tensored_batching_randomness: &[E], transcript: &mut C, @@ -325,6 +327,7 @@ fn sum_check_prove_num_rounds_degree_3< } let proof = if inner_layers[0].numerators.num_evaluations() >= 64 { sumcheck_prove_plain_batched( + evaluation_point, &batched_claims, r_batch, inner_layers, @@ -334,6 +337,7 @@ fn sum_check_prove_num_rounds_degree_3< )? } else { sumcheck_prove_plain_batched_serial( + evaluation_point, &batched_claims, r_batch, inner_layers, diff --git a/sumcheck/src/lib.rs b/sumcheck/src/lib.rs index e411e87e6..d086e46f7 100644 --- a/sumcheck/src/lib.rs +++ b/sumcheck/src/lib.rs @@ -23,7 +23,7 @@ mod verifier; pub use verifier::*; mod univariate; -pub use univariate::{CompressedUnivariatePoly, CompressedUnivariatePolyEvals}; +pub use univariate::CompressedUnivariatePoly; mod multilinear; pub use multilinear::{inner_product, EqFunction, MultiLinearPoly}; diff --git a/sumcheck/src/prover/high_degree.rs b/sumcheck/src/prover/high_degree.rs index 906692b9d..5e42ee9ae 100644 --- a/sumcheck/src/prover/high_degree.rs +++ b/sumcheck/src/prover/high_degree.rs @@ -11,10 +11,10 @@ use math::FieldElement; #[cfg(feature = "concurrent")] pub use rayon::prelude::*; -use super::SumCheckProverError; +use super::{compute_scaling_down_factors, to_coefficients, SumCheckProverError}; use crate::{ - evaluate_composition_poly, CompressedUnivariatePolyEvals, EqFunction, FinalOpeningClaim, - MultiLinearPoly, RoundProof, SumCheckProof, SumCheckRoundClaim, + evaluate_composition_poly, EqFunction, FinalOpeningClaim, MultiLinearPoly, RoundProof, + SumCheckProof, SumCheckRoundClaim, }; /// A sum-check prover for the input layer which can accommodate non-linear expressions in @@ -148,7 +148,102 @@ use crate::{ /// \right) /// $$ /// +/// +/// We now discuss a further optimization due to [2]. Suppose that we have a sum-check statment of +/// the following form: +/// +/// $$v_0=\sum_{x}Eq\left(\left(\alpha_0,\cdots,\alpha_{\nu - 1}\right);\left( x_0, \cdots, x_{\nu - 1}\right)\right) +/// C\left( x_0, \cdots, x_{\nu - 1} \right)$$ +/// +/// Then during round $i + 1$ of sum-check, the prover needs to send the following polynomial +/// +/// $$v_{i+1}(X)=\sum_{x}Eq\left(\left(\alpha_0,\cdots,\alpha_{i - 1},\alpha_i, \alpha_{i+1},\cdots\alpha_{\nu - 1} \right); +/// \left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// We can write $v_{i+1}(X)$ as: +/// +/// $$v_{i+1}(X)=Eq\left(\left(\alpha_0,\cdots,\alpha_{i - 1} \right);\left(r_0,\cdots,r_{i-1}\right)\right) +/// \cdot Eq\left(\alpha_i ;X\right)\sum_{x}Eq\left(\left(\alpha_{i+1},\cdots\alpha_{\nu - 1}\right);\left( x_{i+1}, \cdots x_{\nu - 1}\right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// This means that $v_{i+1}(X)$ is the product of: +/// +/// 1. A constant polynomial: $Eq\left( \left(\alpha_0, \cdots, \alpha_{i - 1} \right);\left( r_0, \cdots, r_{i-1} \right) \right)$ +/// 2. A linear polynomial: $Eq\left( \alpha_i ; X \right)$ +/// 3. A high degree polynomial: $\sum_{x} +/// Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right);\left( x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$ +/// +/// The advantage of the above decomposition is that the prover when computing $v_{i+1}(X)$ needs to sum over +/// +/// $$ +/// Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right);\left( x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right) +/// $$ +/// +/// instead of +/// +/// $$ +/// Eq\left( \left(\alpha_0, \cdots, \alpha_{i - 1}, \alpha_i, \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right) +/// $$ +/// +/// which has the advantage of being of degree $1$ less and hence requires less work on the part of the prover. +/// +/// Thus, the prover computes the following polynomial +/// +/// $$v_{i+1}^{'}(X) = \sum_{x} Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left( x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// and then scales it in order to get +/// +/// $$ +/// v_{i+1}(X) = v_{i+1}^{'}(X) Eq\left( \left(\alpha_0, \cdots, \alpha_{i - 1} \right); +/// \left( r_0, \cdots, r_{i-1} \right) \right) \cdot Eq\left( \alpha_i ; X \right) +/// $$ +/// +/// As the prover computes $v_{i+1}^{'}(X)$ in evaluation form and hence also $v_{i+1}(X)$, this +/// means that due to the degrees being off by $1$, the prover uses the linear factor in order to +/// obtain an additional evaluation point in order to be able to interpolate $v_{i+1}(X)$. +/// More precisely, we can get a root of $$v_{i+1}(X) = 0$$ by solving $$Eq\left( \alpha_i ; X \right) = 0$$ +/// The latter equation has as solution $$\mathsf{r} = \frac{1 - \alpha}{1 - 2\cdot\alpha}$$ +/// which is, except with negligible probability, an evaluation point not in the original +/// evaluation set and hence the prover is able to interpolate $v_{i+1}(X)$ and send it to +/// the verifier. +/// +/// Note that in order to avoid having to compute $\{Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left( x_{i+1}, \cdots x_{\nu - 1} \right) \right)\}$ from $\{Eq\left( \left( \alpha_{i}, \cdots \alpha_{\nu - 1} \right); +/// \left( x_{i}, \cdots x_{\nu - 1} \right) \right)\}$, or vice versa, we can write +/// +/// $$v_{i+1}^{'}(X) = \sum_{x} Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left( x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// as +/// +/// $$v_{i+1}^{'}(X) = \frac{1}{Eq\left( \left( \alpha_{0}, \cdots, \alpha_{i} \right); +/// \left(0, \cdots, 0\right) \right)} \sum_{x} +/// Eq\left( \left( \alpha_{0}, \cdots, \alpha_{i}, \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left(0, \cdots, 0, x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// Thus, $\{Eq\left( \left( \alpha_{0}, \cdots, \alpha_{i}, \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left(0, \cdots, 0, x_{i+1}, \cdots x_{\nu - 1} \right) \right)\}$ can be read from +/// $\{Eq\left( \left( \alpha_{0}, \cdots, \alpha_{\nu - 1} \right);\left(x_{0}, \cdots x_{\nu - 1} \right) \right)\}$ +/// directly, at the cost of the relation between $v_{i+1}^{'}(X)$ and $v_{i+1}(X)$ becoming +/// +/// $$ +/// v_{i+1}(X) = v_{i+1}^{'}(X) \frac{Eq\left( \left(\alpha_0, \cdots, \alpha_{i - 1} \right); +/// \left( r_0, \cdots, r_{i-1} \right) \right)}{Eq\left( \left( \alpha_{0}, \cdots, \alpha_{i} \right); +/// \left(0, \cdots, 0\right) \right)} \cdot Eq\left( \alpha_i ; X \right) +/// $$ +/// +/// /// [1]: https://eprint.iacr.org/2023/1284 +/// [2]: https://eprint.iacr.org/2024/108 #[allow(clippy::too_many_arguments)] pub fn sum_check_prove_higher_degree< E: FieldElement, @@ -168,12 +263,13 @@ pub fn sum_check_prove_higher_degree< let mut round_proofs = vec![]; - let mut eq_mle = EqFunction::ml_at(evaluation_point.clone().into()); + let eq_mle = EqFunction::ml_at(evaluation_point.clone().into()); // setup first round claim let mut current_round_claim = SumCheckRoundClaim { eval_point: vec![], claim }; // run the first round of the protocol - let round_poly_evals = sumcheck_round( + let mut round_poly_evals = sumcheck_round( + 0, &tensored_circuits_batching, evaluator, &eq_mle, @@ -182,7 +278,21 @@ pub fn sum_check_prove_higher_degree< &log_up_randomness, r_sum_check, ); - let round_poly_coefs = round_poly_evals.to_poly(current_round_claim.claim); + + // this will hold `Eq((\alpha_0, \cdots, \alpha_{i - 1});(r_0, \cdots, r_{i-1}))` + let mut scaling_up_factor = E::ONE; + // this will hold `Eq((\alpha_{0}, \cdots, \alpha_{i}); (0, \cdots, 0))` for all `i` + let scaling_down_factors = compute_scaling_down_factors(&evaluation_point); + // this is `\alpha_i` above + let mut alpha_i = evaluation_point[0]; + let scaling_down_factor = scaling_down_factors[0]; + let round_poly_coefs = to_coefficients( + &mut round_poly_evals, + current_round_claim.claim, + alpha_i, + scaling_down_factor, + scaling_up_factor, + ); // reseed with the s_0 polynomial coin.reseed(H::hash_elements(&round_poly_coefs.0)); @@ -192,6 +302,11 @@ pub fn sum_check_prove_higher_degree< let round_challenge = coin.draw().map_err(|_| SumCheckProverError::FailedToGenerateChallenge)?; + // update `scaling_up_factor` + alpha_i = evaluation_point[evaluation_point.len() + 1 - mls[0].num_variables()]; + scaling_up_factor *= + round_challenge * alpha_i + (E::ONE - round_challenge) * (E::ONE - alpha_i); + // compute the new reduced round claim let new_round_claim = reduce_claim(&round_proofs[i - 1], current_round_claim, round_challenge); @@ -199,14 +314,14 @@ pub fn sum_check_prove_higher_degree< // fold each multi-linear using the round challenge mls.iter_mut() .for_each(|ml| ml.bind_least_significant_variable(round_challenge)); - eq_mle.bind_least_significant_variable(round_challenge); // fold each periodic multi-linear using the round challenge periodic_table.bind_least_significant_variable(round_challenge); // run the i-th round of the protocol using the folded multi-linears for the new reduced // claim. This basically computes the s_i polynomial. - let round_poly_evals = sumcheck_round( + let mut round_poly_evals = sumcheck_round( + i, &tensored_circuits_batching, evaluator, &eq_mle, @@ -219,7 +334,14 @@ pub fn sum_check_prove_higher_degree< // update the claim current_round_claim = new_round_claim; - let round_poly_coefs = round_poly_evals.to_poly(current_round_claim.claim); + let alpha_i = evaluation_point[i]; + let round_poly_coefs = to_coefficients( + &mut round_poly_evals, + current_round_claim.claim, + alpha_i, + scaling_down_factors[i], + scaling_up_factor, + ); // reseed with the s_i polynomial coin.reseed(H::hash_elements(&round_poly_coefs.0)); @@ -293,15 +415,17 @@ pub fn sum_check_prove_higher_degree< /// the previous one using only additions. This is the purpose of `deltas`, to hold the increments /// added to each multi-linear to compute the evaluation at the next point, and `evals_x` to hold /// the current evaluation at $x$ in $\{2, ... , d_max\}$. +#[allow(clippy::too_many_arguments)] fn sumcheck_round( - eq_mu: &[E], + sum_check_round: usize, + tensored_circuits_batching: &[E], evaluator: &impl LogUpGkrEvaluator::BaseField>, eq_ml: &MultiLinearPoly, mls: &[MultiLinearPoly], periodic_table: &PeriodicTable, log_up_randomness: &[E], r_sum_check: E, -) -> CompressedUnivariatePolyEvals { +) -> Vec { let num_mls = mls.len(); let num_periodic = periodic_table.num_columns(); let num_vars = mls[0].num_variables(); @@ -325,13 +449,10 @@ fn sumcheck_round( let mut evals_periodic_x_zero = vec![E::ZERO; num_periodic]; let mut evals_periodic_x_one = vec![E::ZERO; num_periodic]; - let mut eq_x = E::ZERO; - let mut deltas_zero = vec![E::ZERO; num_mls]; let mut deltas_one = vec![E::ZERO; num_mls]; let mut deltas_periodic_zero = vec![E::ZERO; num_periodic]; let mut deltas_periodic_one = vec![E::ZERO; num_periodic]; - let mut eq_delta = E::ZERO; let mut numerators_zero = vec![E::ZERO; evaluator.get_num_fractions()]; let mut denominators_zero = vec![E::ZERO; evaluator.get_num_fractions()]; @@ -339,7 +460,7 @@ fn sumcheck_round( let mut denominators_one = vec![E::ZERO; evaluator.get_num_fractions()]; (0..1 << num_rounds) .map(|i| { - let mut total_evals = vec![E::ZERO; evaluator.max_degree()]; + let mut total_evals = vec![E::ZERO; evaluator.max_degree() - 1]; for (j, ml) in mls.iter().enumerate() { evals_zero_zero[j] = ml.evaluations()[2 * i]; evals_zero_one[j] = ml.evaluations()[2 * i + 1]; @@ -347,7 +468,6 @@ fn sumcheck_round( evals_one_one[j] = ml.evaluations()[2 * i + 2 * (1 << num_rounds) + 1]; } let eq_at_zero = eq_ml.evaluations()[i]; - let eq_at_one = eq_ml.evaluations()[i + (1 << num_rounds)]; // add evaluation of periodic columns periodic_table.fill_periodic_values_at(2 * i, &mut evals_periodic_zero_zero); @@ -361,28 +481,28 @@ fn sumcheck_round( &mut evals_periodic_one_one, ); - // compute the evaluation at 1 + // compute the evaluation at 0 evaluator.evaluate_query( - &evals_one_zero, - &evals_periodic_one_zero, + &evals_zero_zero, + &evals_periodic_zero_zero, log_up_randomness, &mut numerators_zero, &mut denominators_zero, ); evaluator.evaluate_query( - &evals_one_one, - &evals_periodic_one_one, + &evals_zero_one, + &evals_periodic_zero_one, log_up_randomness, &mut numerators_one, &mut denominators_one, ); total_evals[0] = evaluate_composition_poly( - eq_mu, + tensored_circuits_batching, &numerators_zero, &denominators_zero, &numerators_one, &denominators_one, - eq_at_one, + eq_at_zero, r_sum_check, ); @@ -400,8 +520,6 @@ fn sumcheck_round( deltas_periodic_one[i] = evals_periodic_one_one[i] - evals_periodic_zero_one[i]; evals_periodic_x_one[i] = evals_periodic_one_one[i]; } - eq_delta = eq_at_one - eq_at_zero; - eq_x = eq_at_one; for e in total_evals.iter_mut().skip(1) { evals_x_zero.iter_mut().zip(deltas_zero.iter()).for_each(|(evx, delta)| { @@ -420,7 +538,6 @@ fn sumcheck_round( *evx += *delta; }, ); - eq_x += eq_delta; evaluator.evaluate_query( &evals_x_zero, @@ -437,19 +554,19 @@ fn sumcheck_round( &mut denominators_one, ); *e = evaluate_composition_poly( - eq_mu, + tensored_circuits_batching, &numerators_zero, &denominators_zero, &numerators_one, &denominators_one, - eq_x, + eq_at_zero, r_sum_check, ); } total_evals }) - .fold(vec![E::ZERO; evaluator.max_degree()], |mut acc, poly_eval| { + .fold(vec![E::ZERO; evaluator.max_degree() - 1], |mut acc, poly_eval| { acc.iter_mut().zip(poly_eval.iter()).for_each(|(a, b)| { *a += *b; }); @@ -483,7 +600,7 @@ fn sumcheck_round( vec![E::ZERO; evaluator.get_num_fractions()], vec![E::ZERO; evaluator.get_num_fractions()], vec![E::ZERO; evaluator.get_num_fractions()], - vec![E::ZERO; evaluator.max_degree()], + vec![E::ZERO; evaluator.max_degree() - 1], ) }, |( @@ -518,7 +635,6 @@ fn sumcheck_round( } let eq_at_zero = eq_ml.evaluations()[i]; - let eq_at_one = eq_ml.evaluations()[i + (1 << num_rounds)]; // add evaluation of periodic columns periodic_table.fill_periodic_values_at(2 * i, &mut evals_periodic_zero_zero); @@ -532,28 +648,28 @@ fn sumcheck_round( &mut evals_periodic_one_one, ); - // compute the evaluation at 1 + // compute the evaluation at 0 evaluator.evaluate_query( - &evals_one_zero, - &evals_periodic_one_zero, + &evals_zero_zero, + &evals_periodic_zero_zero, log_up_randomness, &mut numerators_zero, &mut denominators_zero, ); evaluator.evaluate_query( - &evals_one_one, - &evals_periodic_one_one, + &evals_zero_one, + &evals_periodic_zero_one, log_up_randomness, &mut numerators_one, &mut denominators_one, ); - poly_evals[0] += evaluate_composition_poly( - eq_mu, + total_evals[0] = evaluate_composition_poly( + tensored_circuits_batching, &numerators_zero, &denominators_zero, &numerators_one, &denominators_one, - eq_at_one, + eq_at_zero, r_sum_check, ); @@ -571,8 +687,6 @@ fn sumcheck_round( deltas_periodic_one[i] = evals_periodic_one_one[i] - evals_periodic_zero_one[i]; evals_periodic_x_one[i] = evals_periodic_one_one[i]; } - let eq_delta = eq_at_one - eq_at_zero; - let mut eq_x = eq_at_one; for e in poly_evals.iter_mut().skip(1) { evals_x_zero.iter_mut().zip(deltas_zero.iter()).for_each(|(evx, delta)| { @@ -591,7 +705,6 @@ fn sumcheck_round( *evx += *delta; }, ); - eq_x += eq_delta; evaluator.evaluate_query( &evals_x_zero, @@ -608,12 +721,12 @@ fn sumcheck_round( &mut denominators_one, ); *e += evaluate_composition_poly( - eq_mu, + tensored_circuits_batching, &numerators_zero, &denominators_zero, &numerators_one, &denominators_one, - eq_x, + eq_at_zero, r_sum_check, ); } @@ -645,7 +758,7 @@ fn sumcheck_round( ) .map(|(.., poly_evals)| poly_evals) .reduce( - || vec![E::ZERO; evaluator.max_degree()], + || vec![E::ZERO; evaluator.max_degree() - 1], |mut acc, poly_eval| { acc.iter_mut().zip(poly_eval.iter()).for_each(|(a, b)| { *a += *b; @@ -654,7 +767,7 @@ fn sumcheck_round( }, ); - CompressedUnivariatePolyEvals(evaluations.into()) + evaluations } /// Reduces an old claim to a new claim using the round challenge. diff --git a/sumcheck/src/prover/mod.rs b/sumcheck/src/prover/mod.rs index 354a84b22..ec1d93b1a 100644 --- a/sumcheck/src/prover/mod.rs +++ b/sumcheck/src/prover/mod.rs @@ -5,11 +5,13 @@ mod high_degree; use core::{fmt::Formatter, ops::Add}; +use crate::CompressedUnivariatePoly; use alloc::{fmt, vec::Vec}; pub use high_degree::sum_check_prove_higher_degree; mod plain; +use math::batch_inversion; use math::FieldElement; pub use plain::sumcheck_prove_plain_batched_serial; pub use plain::sumcheck_prove_plain_batched; @@ -170,3 +172,68 @@ impl fmt::Debug for CircuitWire { write!(f, "{} / {}", self.numerator, self.denominator) } } + +// HELPER +// =============================================================================================== + +/// Takes the evaluation of the polynomial $v_{i+1}^{'}(X)$ defined by +/// +/// $$v_{i+1}^{'}(X) = \sum_{x} Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left( x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// and computes the interpolation of the $v_{i+1}(X)$ polynomial defined by +/// +/// $$ +/// v_{i+1}(X) = v_{i+1}^{'}(X) \frac{Eq\left( \left(\alpha_0, \cdots, \alpha_{i - 1} \right); +/// \left( r_0, \cdots, r_{i-1} \right) \right)}{Eq\left( \left( \alpha_{0}, \cdots, \alpha_{i} \right); +/// \left(0, \cdots, 0\right) \right)} \cdot Eq\left( \alpha_i ; X \right) +/// $$ +/// +/// The function returns a `CompressedUnivariatePoly` instead of the full list of coefficients. +fn to_coefficients( + round_poly_evals: &mut [E], + claim: E, + alpha: E, + scaling_down_factor: E, + scaling_up_factor: E, +) -> CompressedUnivariatePoly { + let a = scaling_down_factor; + round_poly_evals.iter_mut().for_each(|e| *e *= scaling_up_factor); + + let mut round_poly_evaluations = Vec::with_capacity(round_poly_evals.len() + 1); + round_poly_evaluations.push(round_poly_evals[0] * compute_weight(alpha, E::ZERO) * a); + round_poly_evaluations.push(claim - round_poly_evaluations[0]); + + for (x, eval) in round_poly_evals.iter().skip(1).enumerate() { + round_poly_evaluations.push(*eval * compute_weight(alpha, E::from(x as u32 + 2)) * a) + } + + let root = (E::ONE - alpha) / (E::ONE - alpha.double()); + + CompressedUnivariatePoly::interpolate_equidistant_points(&round_poly_evaluations, root) +} + +/// Computes +/// +/// $$ +/// Eq\left( \left( \alpha_{0}, \cdots, \alpha_{i} \right); +/// \left(0, \cdots, 0\right) \right) +/// $$ +/// +/// given $(\alpha_0, \cdots, \alpha_{\nu - 1})$ for all $i$ in $0, \cdots, \nu - 1$. +fn compute_scaling_down_factors(gkr_point: &[E]) -> Vec { + let cumulative_product: Vec = gkr_point + .iter() + .scan(E::ONE, |acc, &x| { + *acc *= E::ONE - x; + Some(*acc) + }) + .collect(); + batch_inversion(&cumulative_product) +} + +/// Computes $EQ(x; \alpha)$. +fn compute_weight(alpha: E, x: E) -> E { + x * alpha + (E::ONE - x) * (E::ONE - alpha) +} diff --git a/sumcheck/src/prover/plain.rs b/sumcheck/src/prover/plain.rs index 49ea5f9f5..576ada231 100644 --- a/sumcheck/src/prover/plain.rs +++ b/sumcheck/src/prover/plain.rs @@ -9,13 +9,11 @@ use crypto::{ElementHasher, RandomCoin}; use math::FieldElement; #[cfg(feature = "concurrent")] pub use rayon::prelude::*; -use smallvec::smallvec; -use super::{CircuitLayerPolys, SumCheckProverError}; -use crate::{ - comb_func, CompressedUnivariatePolyEvals, FinalOpeningClaim, MultiLinearPoly, RoundProof, - SumCheckProof, +use super::{ + compute_scaling_down_factors, to_coefficients, CircuitLayerPolys, SumCheckProverError, }; +use crate::{comb_func, FinalOpeningClaim, MultiLinearPoly, RoundProof, SumCheckProof}; /// Sum-check prover for non-linear multivariate polynomial of the simple LogUp-GKR. /// @@ -41,191 +39,112 @@ use crate::{ /// q_{\nu-\kappa+1}\left(0, w_{\kappa+1}, \cdots, w_{\nu}\right)\right) /// $$ /// -/// for $k = 1, \cdots, \nu - 1$ +/// for $k = 1, \cdots, \nu - 1$ /// /// Instead of executing two runs of the sum-check protocol, a batching randomness `r_batch` is /// sent by the verifier at the outset in order to batch the two statments. /// /// Note that the degree of the non-linear composition polynomial is 3. /// +/// +/// We now discuss a further optimization due to [2]. Suppose that we have a sum-check statment of +/// the following form: +/// +/// $$v_0=\sum_{x}Eq\left(\left(\alpha_0,\cdots,\alpha_{\nu - 1}\right);\left( x_0, \cdots, x_{\nu - 1}\right)\right) +/// C\left( x_0, \cdots, x_{\nu - 1} \right)$$ +/// +/// Then during round $i + 1$ of sum-check, the prover needs to send the following polynomial +/// +/// $$v_{i+1}(X)=\sum_{x}Eq\left(\left(\alpha_0,\cdots,\alpha_{i - 1},\alpha_i, \alpha_{i+1},\cdots\alpha_{\nu - 1} \right); +/// \left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// We can write $v_{i+1}(X)$ as: +/// +/// $$v_{i+1}(X)=Eq\left(\left(\alpha_0,\cdots,\alpha_{i - 1} \right);\left(r_0,\cdots,r_{i-1}\right)\right) +/// \cdot Eq\left(\alpha_i ;X\right)\sum_{x}Eq\left(\left(\alpha_{i+1},\cdots\alpha_{\nu - 1}\right);\left( x_{i+1}, \cdots x_{\nu - 1}\right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// This means that $v_{i+1}(X)$ is the product of: +/// +/// 1. A constant polynomial: $Eq\left( \left(\alpha_0, \cdots, \alpha_{i - 1} \right);\left( r_0, \cdots, r_{i-1} \right) \right)$ +/// 2. A linear polynomial: $Eq\left( \alpha_i ; X \right)$ +/// 3. A high degree polynomial: $\sum_{x} +/// Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right);\left( x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$ +/// +/// The advantage of the above decomposition is that the prover when computing $v_{i+1}(X)$ needs to sum over +/// +/// $$ +/// Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right);\left( x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right) +/// $$ +/// +/// instead of +/// +/// $$ +/// Eq\left( \left(\alpha_0, \cdots, \alpha_{i - 1}, \alpha_i, \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right) +/// $$ +/// +/// which has the advantage of being of degree $1$ less and hence requires less work on the part of the prover. +/// +/// Thus, the prover computes the following polynomial +/// +/// $$v_{i+1}^{'}(X) = \sum_{x} Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left( x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// and then scales it in order to get +/// +/// $$ +/// v_{i+1}(X) = v_{i+1}^{'}(X) Eq\left( \left(\alpha_0, \cdots, \alpha_{i - 1} \right); +/// \left( r_0, \cdots, r_{i-1} \right) \right) \cdot Eq\left( \alpha_i ; X \right) +/// $$ +/// +/// As the prover computes $v_{i+1}^{'}(X)$ in evaluation form and hence also $v_{i+1}(X)$, this +/// means that due to the degrees being off by $1$, the prover uses the linear factor in order to +/// obtain an additional evaluation point in order to be able to interpolate $v_{i+1}(X)$. +/// More precisely, we can get a root of $$v_{i+1}(X) = 0$$ by solving $$Eq\left( \alpha_i ; X \right) = 0$$ +/// The latter equation has as solution $$\mathsf{r} = \frac{1 - \alpha}{1 - 2\cdot\alpha}$$ +/// which is, except with negligible probability, an evaluation point not in the original +/// evaluation set and hence the prover is able to interpolate $v_{i+1}(X)$ and send it to +/// the verifier. +/// +/// Note that in order to avoid having to compute $\{Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left( x_{i+1}, \cdots x_{\nu - 1} \right) \right)\}$ from $\{Eq\left( \left( \alpha_{i}, \cdots \alpha_{\nu - 1} \right); +/// \left( x_{i}, \cdots x_{\nu - 1} \right) \right)\}$, or vice versa, we can write +/// +/// $$v_{i+1}^{'}(X) = \sum_{x} Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left( x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// as +/// +/// $$v_{i+1}^{'}(X) = \frac{1}{Eq\left( \left( \alpha_{0}, \cdots, \alpha_{i} \right); +/// \left(0, \cdots, 0\right) \right)} \sum_{x} +/// Eq\left( \left( \alpha_{0}, \cdots, \alpha_{i}, \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left(0, \cdots, 0, x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// Thus, $\{Eq\left( \left( \alpha_{0}, \cdots, \alpha_{i}, \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left(0, \cdots, 0, x_{i+1}, \cdots x_{\nu - 1} \right) \right)\}$ can be read from +/// $\{Eq\left( \left( \alpha_{0}, \cdots, \alpha_{\nu - 1} \right);\left(x_{0}, \cdots x_{\nu - 1} \right) \right)\}$ +/// directly, at the cost of the relation between $v_{i+1}^{'}(X)$ and $v_{i+1}(X)$ becoming +/// +/// $$ +/// v_{i+1}(X) = v_{i+1}^{'}(X) \frac{Eq\left( \left(\alpha_0, \cdots, \alpha_{i - 1} \right); +/// \left( r_0, \cdots, r_{i-1} \right) \right)}{Eq\left( \left( \alpha_{0}, \cdots, \alpha_{i} \right); +/// \left(0, \cdots, 0\right) \right)} \cdot Eq\left( \alpha_i ; X \right) +/// $$ +/// +/// /// [1]: https://eprint.iacr.org/2023/1284 - - -/* -#[allow(clippy::too_many_arguments)] -pub fn sumcheck_prove_plain>( - mut claim: E, - r_batch: E, - p: MultiLinearPoly, - q: MultiLinearPoly, - eq: &mut MultiLinearPoly, - transcript: &mut impl RandomCoin, -) -> Result, SumCheckProverError> { - let mut round_proofs = vec![]; - - let mut challenges = vec![]; - - // construct the vector of multi-linear polynomials - let (mut p0, mut p1) = p.project_least_significant_variable(); - let (mut q0, mut q1) = q.project_least_significant_variable(); - - for _ in 0..p0.num_variables() { - let len = p0.num_evaluations() / 2; - - #[cfg(not(feature = "concurrent"))] - let (round_poly_eval_at_1, round_poly_eval_at_2, round_poly_eval_at_3) = (0..len).fold( - (E::ZERO, E::ZERO, E::ZERO), - |(acc_point_1, acc_point_2, acc_point_3), i| { - let round_poly_eval_at_1 = comb_func( - p0[len + i], - p1[len + i], - q0[len + i], - q1[len + i], - eq[len + i], - r_batch, - ); - - let p0_delta = p0[len + i] - p0[i]; - let p1_delta = p1[len + i] - p1[i]; - let q0_delta = q0[len + i] - q0[i]; - let q1_delta = q1[len + i] - q1[i]; - let eq_delta = eq[len + i] - eq[i]; - - let mut p0_eval_at_x = p0[len + i] + p0_delta; - let mut p1_eval_at_x = p1[len + i] + p1_delta; - let mut q0_eval_at_x = q0[len + i] + q0_delta; - let mut q1_eval_at_x = q1[len + i] + q1_delta; - let mut eq_evx = eq[len + i] + eq_delta; - let round_poly_eval_at_2 = comb_func( - p0_eval_at_x, - p1_eval_at_x, - q0_eval_at_x, - q1_eval_at_x, - eq_evx, - r_batch, - ); - - p0_eval_at_x += p0_delta; - p1_eval_at_x += p1_delta; - q0_eval_at_x += q0_delta; - q1_eval_at_x += q1_delta; - eq_evx += eq_delta; - let round_poly_eval_at_3 = comb_func( - p0_eval_at_x, - p1_eval_at_x, - q0_eval_at_x, - q1_eval_at_x, - eq_evx, - r_batch, - ); - - ( - round_poly_eval_at_1 + acc_point_1, - round_poly_eval_at_2 + acc_point_2, - round_poly_eval_at_3 + acc_point_3, - ) - }, - ); - - #[cfg(feature = "concurrent")] - let (round_poly_eval_at_1, round_poly_eval_at_2, round_poly_eval_at_3) = (0..len) - .into_par_iter() - .fold( - || (E::ZERO, E::ZERO, E::ZERO), - |(a, b, c), i| { - let round_poly_eval_at_1 = comb_func( - p0[len + i], - p1[len + i], - q0[len + i], - q1[len + i], - eq[len + i], - r_batch, - ); - - let p0_delta = p0[len + i] - p0[i]; - let p1_delta = p1[len + i] - p1[i]; - let q0_delta = q0[len + i] - q0[i]; - let q1_delta = q1[len + i] - q1[i]; - let eq_delta = eq[len + i] - eq[i]; - - let mut p0_eval_at_x = p0[len + i] + p0_delta; - let mut p1_eval_at_x = p1[len + i] + p1_delta; - let mut q0_eval_at_x = q0[len + i] + q0_delta; - let mut q1_eval_at_x = q1[len + i] + q1_delta; - let mut eq_evx = eq[len + i] + eq_delta; - let round_poly_eval_at_2 = comb_func( - p0_eval_at_x, - p1_eval_at_x, - q0_eval_at_x, - q1_eval_at_x, - eq_evx, - r_batch, - ); - - p0_eval_at_x += p0_delta; - p1_eval_at_x += p1_delta; - q0_eval_at_x += q0_delta; - q1_eval_at_x += q1_delta; - eq_evx += eq_delta; - let round_poly_eval_at_3 = comb_func( - p0_eval_at_x, - p1_eval_at_x, - q0_eval_at_x, - q1_eval_at_x, - eq_evx, - r_batch, - ); - (round_poly_eval_at_1 + a, round_poly_eval_at_2 + b, round_poly_eval_at_3 + c) - }, - ) - .reduce( - || (E::ZERO, E::ZERO, E::ZERO), - |(a0, b0, c0), (a1, b1, c1)| (a0 + a1, b0 + b1, c0 + c1), - ); - - let evals = smallvec![round_poly_eval_at_1, round_poly_eval_at_2, round_poly_eval_at_3]; - let compressed_round_poly_evals = CompressedUnivariatePolyEvals(evals); - let compressed_round_poly = compressed_round_poly_evals.to_poly(claim); - - // reseed with the s_i polynomial - transcript.reseed(H::hash_elements(&compressed_round_poly.0)); - let round_proof = RoundProof { - round_poly_coefs: compressed_round_poly.clone(), - }; - - let round_challenge = - transcript.draw().map_err(|_| SumCheckProverError::FailedToGenerateChallenge)?; - - // fold each multi-linear using the round challenge - p0.bind_least_significant_variable(round_challenge); - p1.bind_least_significant_variable(round_challenge); - q0.bind_least_significant_variable(round_challenge); - q1.bind_least_significant_variable(round_challenge); - eq.bind_least_significant_variable(round_challenge); - - // compute the new reduced round claim - claim = compressed_round_poly.evaluate_using_claim(&claim, &round_challenge); - - round_proofs.push(round_proof); - challenges.push(round_challenge); - } - - Ok(SumCheckProof { - openings_claim: FinalOpeningClaim { - eval_point: challenges, - openings: vec![p0[0], p1[0], q0[0], q1[0]], - }, - round_proofs, - }) -} - - -*/ - - - +/// [2]: https://eprint.iacr.org/2024/108 pub fn sumcheck_prove_plain_batched>( claims: &[E], + gkr_point: &[E], r_batch: E, mut inner_layers: Vec>, eq: &mut MultiLinearPoly, @@ -243,9 +162,12 @@ pub fn sumcheck_prove_plain_batched, >( claims: &[E], + gkr_point: &[E], r_batch: E, mut inner_layers: Vec>, eq: &mut MultiLinearPoly, @@ -462,82 +366,63 @@ pub fn sumcheck_prove_plain_batched_serial< acc + claim_for_circuit_i * randomness_for_circuit_i }); + let scaling_down_factors = compute_scaling_down_factors(gkr_point); + let mut scaling_up_factor = E::ONE; + let num_sum_check_rounds = inner_layers[0].numerators.num_variables() - 1; - for _ in 0..num_sum_check_rounds { - let mut all_round_poly_eval_at_1 = E::ZERO; + for i in 0..num_sum_check_rounds { + let mut all_round_poly_eval_at_0 = E::ZERO; let mut all_round_poly_eval_at_2 = E::ZERO; - let mut all_round_poly_eval_at_3 = E::ZERO; let len = inner_layers[0].numerators.num_evaluations() / 4; for (inner_layer, batching_randomness) in inner_layers.iter().zip(tensored_batching_randomness) { - let (round_poly_eval_at_1, round_poly_eval_at_2, round_poly_eval_at_3) = (0..len).fold( - (E::ZERO, E::ZERO, E::ZERO), - |(acc_point_1, acc_point_2, acc_point_3), i| { - let p0_i_0 = inner_layer.numerators[2 * i]; - let p0_i_1 = inner_layer.numerators[2 * i + 1]; - let p1_i_0 = inner_layer.numerators[2 * (i + len)]; - let p1_i_1 = inner_layer.numerators[2 * (i + len) + 1]; - let q0_i_0 = inner_layer.denominators[2 * i]; - let q0_i_1 = inner_layer.denominators[2 * i + 1]; - let q1_i_0 = inner_layer.denominators[2 * (i + len)]; - let q1_i_1 = inner_layer.denominators[2 * (i + len) + 1]; - let round_poly_eval_at_1 = - comb_func(p1_i_0, p1_i_1, q1_i_0, q1_i_1, eq[i + len], r_batch); + let (round_poly_eval_at_0, round_poly_eval_at_2) = + (0..len).fold((E::ZERO, E::ZERO), |(acc_point_0, acc_point_2), k| { + let p0_i_0 = inner_layer.numerators[2 * k]; + let p0_i_1 = inner_layer.numerators[2 * k + 1]; + let p1_i_0 = inner_layer.numerators[2 * (k + len)]; + let p1_i_1 = inner_layer.numerators[2 * (k + len) + 1]; + let q0_i_0 = inner_layer.denominators[2 * k]; + let q0_i_1 = inner_layer.denominators[2 * k + 1]; + let q1_i_0 = inner_layer.denominators[2 * (k + len)]; + let q1_i_1 = inner_layer.denominators[2 * (k + len) + 1]; + let round_poly_eval_at_0 = + comb_func(p0_i_0, p0_i_1, q0_i_0, q0_i_1, eq[k], r_batch); let p0_delta = p1_i_0 - p0_i_0; let p1_delta = p1_i_1 - p0_i_1; let q0_delta = q1_i_0 - q0_i_0; let q1_delta = q1_i_1 - q0_i_1; - let eq_delta = eq[i + len] - eq[i]; - let mut p0_eval_at_x = p1_i_0 + p0_delta; - let mut p1_eval_at_x = p1_i_1 + p1_delta; - let mut q0_eval_at_x = q1_i_0 + q0_delta; - let mut q1_eval_at_x = q1_i_1 + q1_delta; - let mut eq_evx = eq[i + len] + eq_delta; + let p0_eval_at_2 = p1_i_0 + p0_delta; + let p1_eval_at_2 = p1_i_1 + p1_delta; + let q0_eval_at_2 = q1_i_0 + q0_delta; + let q1_eval_at_2 = q1_i_1 + q1_delta; let round_poly_eval_at_2 = comb_func( - p0_eval_at_x, - p1_eval_at_x, - q0_eval_at_x, - q1_eval_at_x, - eq_evx, + p0_eval_at_2, + p1_eval_at_2, + q0_eval_at_2, + q1_eval_at_2, + eq[k], r_batch, ); - p0_eval_at_x += p0_delta; - p1_eval_at_x += p1_delta; - q0_eval_at_x += q0_delta; - q1_eval_at_x += q1_delta; - eq_evx += eq_delta; - let round_poly_eval_at_3 = comb_func( - p0_eval_at_x, - p1_eval_at_x, - q0_eval_at_x, - q1_eval_at_x, - eq_evx, - r_batch, - ); + (round_poly_eval_at_0 + acc_point_0, round_poly_eval_at_2 + acc_point_2) + }); - ( - round_poly_eval_at_1 + acc_point_1, - round_poly_eval_at_2 + acc_point_2, - round_poly_eval_at_3 + acc_point_3, - ) - }, - ); - - all_round_poly_eval_at_1 += round_poly_eval_at_1 * *batching_randomness; + all_round_poly_eval_at_0 += round_poly_eval_at_0 * *batching_randomness; all_round_poly_eval_at_2 += round_poly_eval_at_2 * *batching_randomness; - all_round_poly_eval_at_3 += round_poly_eval_at_3 * *batching_randomness; } - - let evals = - smallvec![all_round_poly_eval_at_1, all_round_poly_eval_at_2, all_round_poly_eval_at_3]; - let compressed_round_poly_evals = CompressedUnivariatePolyEvals(evals); - let compressed_round_poly = - compressed_round_poly_evals.to_poly(batched_claim_across_circuits); + let alpha_i = gkr_point[i]; + let compressed_round_poly = to_coefficients( + &mut [all_round_poly_eval_at_0, all_round_poly_eval_at_2], + batched_claim_across_circuits, + alpha_i, + scaling_down_factors[i], + scaling_up_factor, + ); // reseed with the s_i polynomial transcript.reseed(H::hash_elements(&compressed_round_poly.0)); @@ -553,7 +438,9 @@ pub fn sumcheck_prove_plain_batched_serial< inner_layer.numerators.bind_least_significant_variable(round_challenge); inner_layer.denominators.bind_least_significant_variable(round_challenge); } - eq.bind_least_significant_variable(round_challenge); + // update the scaling up factor + scaling_up_factor *= + round_challenge * alpha_i + (E::ONE - round_challenge) * (E::ONE - alpha_i); // compute the new reduced round claim batched_claim_across_circuits = compressed_round_poly diff --git a/sumcheck/src/univariate.rs b/sumcheck/src/univariate.rs index 082a4daf9..ebc6dfa47 100644 --- a/sumcheck/src/univariate.rs +++ b/sumcheck/src/univariate.rs @@ -48,65 +48,29 @@ impl CompressedUnivariatePoly { // evaluate polynom::eval(&complete_coefficients, *challenge) } -} - -impl Serializable for CompressedUnivariatePoly { - fn write_into(&self, target: &mut W) { - let vector: Vec = self.0.clone().into_vec(); - vector.write_into(target); - } -} -impl Deserializable for CompressedUnivariatePoly -where - E: FieldElement, -{ - fn read_from(source: &mut R) -> Result { - let vector: Vec = Vec::::read_from(source)?; - Ok(Self(vector.into())) - } -} + /// Given the evaluations of a polynomial over the set $0, 1, \cdots, d - 1$ and a `root` not in + /// the interpolation set, computes its coefficients. + pub fn interpolate_equidistant_points(ys: &[E], root: E) -> CompressedUnivariatePoly { + // we factor out the term `(x - r)` where `r` is the root + let quotient: Vec = (0..ys.len()).map(|i| E::from(i as u32) - root).collect(); + let quotient_inv = batch_inversion("ient); + let mut ys: Vec = ys.iter().zip(quotient_inv.iter()).map(|(&y, &q)| y * q).collect(); -/// The evaluations of a univariate polynomial of degree n at 0, 1, ..., n with the evaluation at 0 -/// omitted. -/// -/// This compressed representation is useful during the sum-check protocol as the full uncompressed -/// representation can be recovered from the compressed one and the current sum-check round claim. -#[derive(Clone, Debug)] -pub struct CompressedUnivariatePolyEvals(pub(crate) SmallVec<[E; MAX_POLY_SIZE]>); + // the zeroth coefficient can be recovered immediately + let c0 = ys.remove(0); -impl CompressedUnivariatePolyEvals { - /// Gives the coefficient representation of a polynomial represented in evaluation form. - /// - /// Since the evaluation at 0 is omitted, we need to use the round claim to recover - /// the evaluation at 0 using the identity $p(0) + p(1) = claim$. - /// Now, we have that for any polynomial $p(x) = c0 + c1 * x + ... + c_{n-1} * x^{n - 1}$: - /// - /// 1. $p(0) = c0$. - /// 2. $p(x) = c0 + x * q(x) where q(x) = c1 + ... + c_{n-1} * x^{n - 2}$. - /// - /// This means that we can compute the evaluations of q at 1, ..., n - 1 using the evaluations - /// of p and thus reduce by 1 the size of the interpolation problem. - /// Once the coefficient of q are recovered, the c0 coefficient is appended to these and this - /// is precisely the coefficient representation of the original polynomial q. - /// Note that the coefficient of the linear term is removed as this coefficient can be recovered - /// from the remaining coefficients, again, using the round claim using the relation - /// $2 * c0 + c1 + ... c_{n - 1} = claim$. - pub fn to_poly(&self, round_claim: E) -> CompressedUnivariatePoly { - // construct the vector of interpolation points 1, ..., n - let n_minus_1 = self.0.len(); + // build the interpolation set + let n_minus_1 = ys.len(); let points = (1..=n_minus_1 as u32).map(E::BaseField::from).collect::>(); // construct their inverses. These will be needed for computing the evaluations - // of the q polynomial as well as for doing the interpolation on q + // of the q polynomial as well as for doing the interpolation on q where q is + // defined as $p(x) = c0 + x * q(x) where q(x) = c1 + ... + c_{n-1} * x^{n - 2}$ let points_inv = batch_inversion(&points); - // compute the zeroth coefficient - let c0 = round_claim - self.0[0]; - // compute the evaluations of q - let q_evals: Vec = self - .0 + let q_evals: Vec = ys .iter() .enumerate() .map(|(i, evals)| (*evals - c0).mul_base(points_inv[i])) @@ -118,11 +82,34 @@ impl CompressedUnivariatePolyEvals { // append c0 to the coefficients of q to get the coefficients of p. The linear term // coefficient is removed as this can be recovered from the other coefficients using // the reduced claim. - let mut coefficients = SmallVec::with_capacity(self.0.len() + 1); + let mut coefficients = SmallVec::<[E; MAX_POLY_SIZE]>::with_capacity(ys.len() + 1); coefficients.push(c0); - coefficients.extend_from_slice(&q_coefs[1..]); + coefficients.extend_from_slice(&q_coefs[..]); + + // multiply back the factor `(x - r)` + let mut p_coefficients = polynom::mul(&coefficients, &[-root, E::ONE]); + + // remove the linear factor as it can be recovered from the `claim` and the other factors + p_coefficients.remove(1); + + CompressedUnivariatePoly(p_coefficients.into()) + } +} - CompressedUnivariatePoly(coefficients) +impl Serializable for CompressedUnivariatePoly { + fn write_into(&self, target: &mut W) { + let vector: Vec = self.0.clone().into_vec(); + vector.write_into(target); + } +} + +impl Deserializable for CompressedUnivariatePoly +where + E: FieldElement, +{ + fn read_from(source: &mut R) -> Result { + let vector: Vec = Vec::::read_from(source)?; + Ok(Self(vector.into())) } } @@ -259,22 +246,18 @@ fn test_poly_partial() { use math::fields::f64::BaseElement; let degree = 1000; - let mut points: Vec = vec![BaseElement::ZERO; degree]; - points - .iter_mut() - .enumerate() - .for_each(|(i, node)| *node = BaseElement::from(i as u32)); + // compute the claim let p: Vec = rand_utils::rand_vector(degree); - let evals = polynom::eval_many(&p, &points); - - let mut partial_evals = evals.clone(); - partial_evals.remove(0); - - let partial_poly = CompressedUnivariatePolyEvals(partial_evals.into()); + let evals = polynom::eval_many(&p, &[BaseElement::ZERO, BaseElement::ONE]); let claim = evals[0] + evals[1]; - let poly_coeff = partial_poly.to_poly(claim); + // build compressed polynomial + let mut poly_coeff = p.clone(); + poly_coeff.remove(1); + let poly_coeff = CompressedUnivariatePoly(poly_coeff.into()); + + // generate random challenge let r = rand_utils::rand_vector(1); assert_eq!(polynom::eval(&p, r[0]), poly_coeff.evaluate_using_claim(&claim, &r[0])) From 13cf6222e3fb27e2cd478de0b74b6cb04c36fc32 Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Wed, 25 Sep 2024 20:17:54 +0200 Subject: [PATCH 20/44] fix parallel plain sum-check --- sumcheck/src/prover/high_degree.rs | 2 +- sumcheck/src/prover/plain.rs | 126 +++++++++++------------------ 2 files changed, 49 insertions(+), 79 deletions(-) diff --git a/sumcheck/src/prover/high_degree.rs b/sumcheck/src/prover/high_degree.rs index 5e42ee9ae..503cf7628 100644 --- a/sumcheck/src/prover/high_degree.rs +++ b/sumcheck/src/prover/high_degree.rs @@ -663,7 +663,7 @@ fn sumcheck_round( &mut numerators_one, &mut denominators_one, ); - total_evals[0] = evaluate_composition_poly( + poly_evals[0] = evaluate_composition_poly( tensored_circuits_batching, &numerators_zero, &denominators_zero, diff --git a/sumcheck/src/prover/plain.rs b/sumcheck/src/prover/plain.rs index 576ada231..79bf4dcc5 100644 --- a/sumcheck/src/prover/plain.rs +++ b/sumcheck/src/prover/plain.rs @@ -171,83 +171,53 @@ pub fn sumcheck_prove_plain_batched Date: Thu, 26 Sep 2024 07:40:26 +0200 Subject: [PATCH 21/44] wip: plain sumcheck failing --- sumcheck/src/prover/plain.rs | 9 ++++++--- sumcheck/src/verifier/mod.rs | 1 + 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/sumcheck/src/prover/plain.rs b/sumcheck/src/prover/plain.rs index 79bf4dcc5..24eaf0593 100644 --- a/sumcheck/src/prover/plain.rs +++ b/sumcheck/src/prover/plain.rs @@ -6,6 +6,7 @@ use alloc::vec::Vec; use crypto::{ElementHasher, RandomCoin}; +use libc_print::libc_println; use math::FieldElement; #[cfg(feature = "concurrent")] pub use rayon::prelude::*; @@ -166,8 +167,10 @@ pub fn sumcheck_prove_plain_batched Date: Thu, 26 Sep 2024 09:11:04 +0200 Subject: [PATCH 22/44] fix: sum-check failing --- air/src/air/logup_gkr/lagrange/transition.rs | 2 +- air/src/air/logup_gkr/mod.rs | 14 ++--- prover/src/logup_gkr/mod.rs | 2 +- prover/src/logup_gkr/prover.rs | 17 +++--- sumcheck/benches/sum_check_high_degree.rs | 3 + sumcheck/benches/sum_check_plain.rs | 63 ++++++++++++++++---- sumcheck/src/prover/high_degree.rs | 13 ++-- sumcheck/src/prover/mod.rs | 18 +++--- sumcheck/src/prover/plain.rs | 9 ++- sumcheck/src/verifier/mod.rs | 17 +++--- verifier/src/logup_gkr/mod.rs | 18 +++--- winterfell/src/tests/logup_gkr_simple.rs | 15 +---- 12 files changed, 108 insertions(+), 83 deletions(-) diff --git a/air/src/air/logup_gkr/lagrange/transition.rs b/air/src/air/logup_gkr/lagrange/transition.rs index 51654b2c2..0b7e72344 100644 --- a/air/src/air/logup_gkr/lagrange/transition.rs +++ b/air/src/air/logup_gkr/lagrange/transition.rs @@ -132,7 +132,7 @@ impl LagrangeKernelTransitionConstraints { let c = lagrange_kernel_column_frame; let v = c.num_rows() - 1; - let r = lagrange_kernel_rand_elements; + let r = lagrange_kernel_rand_elements; // TODO: avoid reverse() let mut r = r.to_vec(); r.reverse(); diff --git a/air/src/air/logup_gkr/mod.rs b/air/src/air/logup_gkr/mod.rs index d1f0f4c49..da0e83124 100644 --- a/air/src/air/logup_gkr/mod.rs +++ b/air/src/air/logup_gkr/mod.rs @@ -302,16 +302,14 @@ where pub fn bind_least_significant_variable(&mut self, round_challenge: E) { for col in self.table.iter_mut() { - if col.len() > 1 { - if self.num_rows <= col.len() { - let num_evals = col.len() >> 1; - for i in 0..num_evals { - col[i] = col[i] + round_challenge * (col[i + num_evals] - col[i]); - } - col.truncate(num_evals) + if col.len() > 1 && self.num_rows <= col.len() { + let num_evals = col.len() >> 1; + for i in 0..num_evals { + col[i] = col[i] + round_challenge * (col[i + num_evals] - col[i]); } + col.truncate(num_evals) } } - self.num_rows = self.num_rows / 2; + self.num_rows /= 2; } } diff --git a/prover/src/logup_gkr/mod.rs b/prover/src/logup_gkr/mod.rs index ce36207c6..5b4e8b5c5 100644 --- a/prover/src/logup_gkr/mod.rs +++ b/prover/src/logup_gkr/mod.rs @@ -13,7 +13,7 @@ pub use prover::prove_gkr; #[cfg(feature = "concurrent")] pub use utils::{ rayon::{current_num_threads as rayon_num_threads, prelude::*}, - {chunks_mut, iter, iter_mut, batch_iter_mut}, + {batch_iter_mut, chunks_mut, iter, iter_mut}, }; // EVALUATED CIRCUIT diff --git a/prover/src/logup_gkr/prover.rs b/prover/src/logup_gkr/prover.rs index f6b7c1d35..d1a91bf67 100644 --- a/prover/src/logup_gkr/prover.rs +++ b/prover/src/logup_gkr/prover.rs @@ -201,6 +201,7 @@ fn build_mle_from_main_trace_segment( /// Proves all GKR layers except for input layer. #[instrument(skip_all)] +#[allow(clippy::type_complexity)] fn prove_intermediate_layers< E: FieldElement, C: RandomCoin, @@ -216,8 +217,8 @@ fn prove_intermediate_layers< let mut total_evaluations = Vec::with_capacity(output_layers[0].numerators.evaluations().len() * 2); - for output_layer in output_layers.into_iter() { - total_evaluations.extend_from_slice(&output_layer.numerators.evaluations()); + for output_layer in output_layers.iter() { + total_evaluations.extend_from_slice(output_layer.numerators.evaluations()); total_evaluations.extend_from_slice(output_layer.denominators.evaluations()); } transcript.reseed(H::hash_elements(&total_evaluations)); @@ -263,10 +264,12 @@ fn prove_intermediate_layers< transcript, )?; - // sample a random challenge to reduce claims - for tmp in proof.openings_claim.openings.iter() { - transcript.reseed(H::hash_elements(tmp)); + // generate the random challenge to reduce two claims into a single claim + let mut total_openings = Vec::with_capacity(proof.openings_claim.openings.len() * 4); + for opening_circuit_i in proof.openings_claim.openings.iter() { + total_openings.extend_from_slice(&opening_circuit_i); } + transcript.reseed(H::hash_elements(&total_openings)); let r_layer = transcript.draw().map_err(|_| GkrProverError::FailedToGenerateChallenge)?; // reduce the claims @@ -327,8 +330,8 @@ fn sum_check_prove_num_rounds_degree_3< } let proof = if inner_layers[0].numerators.num_evaluations() >= 64 { sumcheck_prove_plain_batched( - evaluation_point, &batched_claims, + evaluation_point, r_batch, inner_layers, eq, @@ -337,8 +340,8 @@ fn sum_check_prove_num_rounds_degree_3< )? } else { sumcheck_prove_plain_batched_serial( - evaluation_point, &batched_claims, + evaluation_point, r_batch, inner_layers, eq, diff --git a/sumcheck/benches/sum_check_high_degree.rs b/sumcheck/benches/sum_check_high_degree.rs index 483890579..d6bb8815e 100644 --- a/sumcheck/benches/sum_check_high_degree.rs +++ b/sumcheck/benches/sum_check_high_degree.rs @@ -44,6 +44,8 @@ fn sum_check_high_degree(c: &mut Criterion) { )| { let mls = vec![ml0, ml1, ml2, ml3, ml4]; let mut transcript = transcript; + let tensored_batching_randmoness = + rand_vector(evaluator.get_num_fractions().ilog2() as usize); sum_check_prove_higher_degree( &evaluator, @@ -53,6 +55,7 @@ fn sum_check_high_degree(c: &mut Criterion) { logup_randomness, mls, periodic_table, + &tensored_batching_randmoness, &mut transcript, ) }, diff --git a/sumcheck/benches/sum_check_plain.rs b/sumcheck/benches/sum_check_plain.rs index c7f0552bd..c031fedd2 100644 --- a/sumcheck/benches/sum_check_plain.rs +++ b/sumcheck/benches/sum_check_plain.rs @@ -11,7 +11,9 @@ use math::{fields::f64::BaseElement, FieldElement}; use rand_utils::{rand_value, rand_vector}; #[cfg(feature = "concurrent")] pub use rayon::prelude::*; -use winter_sumcheck::{sumcheck_prove_plain_batched, EqFunction, MultiLinearPoly}; +use winter_sumcheck::{ + sumcheck_prove_plain_batched, CircuitLayerPolys, EqFunction, MultiLinearPoly, +}; const LOG_POLY_SIZE: [usize; 2] = [18, 20]; @@ -26,13 +28,31 @@ fn sum_check_plain(c: &mut Criterion) { || { let transcript = DefaultRandomCoin::>::new(&[BaseElement::ZERO; 4]); - (setup_sum_check::(log_poly_size), transcript) + (setup_sum_check::(log_poly_size, 4), transcript) }, - |((claim, r_batch, p, q, eq), transcript)| { + |( + ( + claims, + evaluation_point, + r_batch, + inner_layers, + tensored_batching_randomness, + eq, + ), + transcript, + )| { let mut eq = eq; let mut transcript = transcript; - sumcheck_prove_plain_batched(claim, r_batch, p, q, &mut eq, &mut transcript) + sumcheck_prove_plain_batched( + &claims, + &evaluation_point, + r_batch, + inner_layers, + &mut eq, + &tensored_batching_randomness, + &mut transcript, + ) }, BatchSize::SmallInput, ) @@ -44,22 +64,39 @@ fn sum_check_plain(c: &mut Criterion) { #[allow(clippy::type_complexity)] fn setup_sum_check( log_size: usize, -) -> (E, E, MultiLinearPoly, MultiLinearPoly, MultiLinearPoly) { + num_fractions: usize, +) -> (Vec, Vec, E, Vec>, Vec, MultiLinearPoly) { let n = 1 << (log_size + 1); - let p: Vec = rand_vector(n); - let q: Vec = rand_vector(n); // this will not generate the correct claim with overwhelming probability but should be fine // for benchmarking - let rand_pt = rand_vector(log_size); + let evaluation_point = rand_vector(log_size); let r_batch: E = rand_value(); - let claim: E = rand_value(); + let claims: Vec = vec![rand_value(); num_fractions]; - let p = MultiLinearPoly::from_evaluations(p); - let q = MultiLinearPoly::from_evaluations(q); - let eq = MultiLinearPoly::from_evaluations(EqFunction::new(rand_pt.into()).evaluations()); + let mut inner_layers = Vec::with_capacity(num_fractions); + for _ in 0..num_fractions { + let p: Vec = rand_vector(n); + let q: Vec = rand_vector(n); + let p = MultiLinearPoly::from_evaluations(p); + let q = MultiLinearPoly::from_evaluations(q); + let inner_layer = CircuitLayerPolys::from_mle(p, q); + inner_layers.push(inner_layer) + } + let eq = MultiLinearPoly::from_evaluations( + EqFunction::new(evaluation_point.clone().into()).evaluations(), + ); + let tensored_batching_randomness = + EqFunction::new(rand_vector::(num_fractions).into()).evaluations(); - (claim, r_batch, p, q, eq) + ( + claims, + evaluation_point, + r_batch, + inner_layers, + tensored_batching_randomness, + eq, + ) } criterion_group!(group, sum_check_plain); diff --git a/sumcheck/src/prover/high_degree.rs b/sumcheck/src/prover/high_degree.rs index 503cf7628..7b525b617 100644 --- a/sumcheck/src/prover/high_degree.rs +++ b/sumcheck/src/prover/high_degree.rs @@ -269,8 +269,7 @@ pub fn sum_check_prove_higher_degree< // run the first round of the protocol let mut round_poly_evals = sumcheck_round( - 0, - &tensored_circuits_batching, + tensored_circuits_batching, evaluator, &eq_mle, &mls, @@ -321,8 +320,7 @@ pub fn sum_check_prove_higher_degree< // run the i-th round of the protocol using the folded multi-linears for the new reduced // claim. This basically computes the s_i polynomial. let mut round_poly_evals = sumcheck_round( - i, - &tensored_circuits_batching, + tensored_circuits_batching, evaluator, &eq_mle, &mls, @@ -417,7 +415,6 @@ pub fn sum_check_prove_higher_degree< /// the current evaluation at $x$ in $\{2, ... , d_max\}$. #[allow(clippy::too_many_arguments)] fn sumcheck_round( - sum_check_round: usize, tensored_circuits_batching: &[E], evaluator: &impl LogUpGkrEvaluator::BaseField>, eq_ml: &MultiLinearPoly, @@ -506,7 +503,7 @@ fn sumcheck_round( r_sum_check, ); - // compute the evaluations at 2, ..., d_max points + // compute the evaluations at 2, ..., d_max - 1 points for i in 0..num_mls { deltas_zero[i] = evals_one_zero[i] - evals_zero_zero[i]; evals_x_zero[i] = evals_one_zero[i]; @@ -663,7 +660,7 @@ fn sumcheck_round( &mut numerators_one, &mut denominators_one, ); - poly_evals[0] = evaluate_composition_poly( + poly_evals[0] += evaluate_composition_poly( tensored_circuits_batching, &numerators_zero, &denominators_zero, @@ -673,7 +670,7 @@ fn sumcheck_round( r_sum_check, ); - // compute the evaluations at 2, ..., d_max points + // compute the evaluations at 2, ..., d_max - 1 points for i in 0..num_mls { deltas_zero[i] = evals_one_zero[i] - evals_zero_zero[i]; evals_x_zero[i] = evals_one_zero[i]; diff --git a/sumcheck/src/prover/mod.rs b/sumcheck/src/prover/mod.rs index ec1d93b1a..135af4620 100644 --- a/sumcheck/src/prover/mod.rs +++ b/sumcheck/src/prover/mod.rs @@ -4,17 +4,16 @@ // LICENSE file in the root directory of this source tree. mod high_degree; +use alloc::{fmt, vec::Vec}; use core::{fmt::Formatter, ops::Add}; -use crate::CompressedUnivariatePoly; -use alloc::{fmt, vec::Vec}; pub use high_degree::sum_check_prove_higher_degree; +use crate::CompressedUnivariatePoly; + mod plain; -use math::batch_inversion; -use math::FieldElement; -pub use plain::sumcheck_prove_plain_batched_serial; -pub use plain::sumcheck_prove_plain_batched; +use math::{batch_inversion, FieldElement}; +pub use plain::{sumcheck_prove_plain_batched, sumcheck_prove_plain_batched_serial}; mod error; pub use error::SumCheckProverError; @@ -33,7 +32,6 @@ pub struct CircuitLayerPolys { pub denominators: MultiLinearPoly, } - impl Serializable for CircuitLayerPolys where E: FieldElement, @@ -57,11 +55,9 @@ where } } - // CIRCUIT LAYER POLYS // =============================================================================================== - impl CircuitLayerPolys where E: FieldElement, @@ -88,6 +84,10 @@ where denominators: MultiLinearPoly::from_evaluations(denominators), } } + + pub fn from_mle(numerators: MultiLinearPoly, denominators: MultiLinearPoly) -> Self { + Self { numerators, denominators } + } } // CIRCUIT LAYER diff --git a/sumcheck/src/prover/plain.rs b/sumcheck/src/prover/plain.rs index 24eaf0593..305e3e234 100644 --- a/sumcheck/src/prover/plain.rs +++ b/sumcheck/src/prover/plain.rs @@ -6,7 +6,6 @@ use alloc::vec::Vec; use crypto::{ElementHasher, RandomCoin}; -use libc_print::libc_println; use math::FieldElement; #[cfg(feature = "concurrent")] pub use rayon::prelude::*; @@ -167,10 +166,7 @@ pub fn sumcheck_prove_plain_batched( for col in periodic_columns.table() { let ml = MultiLinearPoly::from_evaluations(col.to_vec()); let num_variables = ml.num_variables(); - let point_zero = &eval_point_zero[&eval_point_zero.len() - num_variables..]; - let point_one = &eval_point_one[&eval_point_one.len() - num_variables..]; + let point_zero = &eval_point_zero[eval_point_zero.len() - num_variables..]; + let point_one = &eval_point_one[eval_point_one.len() - num_variables..]; let evaluation_zero = ml.evaluate(point_zero); evaluations_zero.push(evaluation_zero); diff --git a/verifier/src/logup_gkr/mod.rs b/verifier/src/logup_gkr/mod.rs index 57b18631d..5ceaea5c0 100644 --- a/verifier/src/logup_gkr/mod.rs +++ b/verifier/src/logup_gkr/mod.rs @@ -38,9 +38,7 @@ pub fn verify_gkr< let mut total_evaluations = Vec::with_capacity(numerators.len() * 4); let mut num_acc = E::ZERO; let mut den_acc = E::ONE; - for (_circuit_id, (nums, dens)) in - numerators.into_iter().zip(denominators.into_iter()).enumerate() - { + for (nums, dens) in numerators.iter().zip(denominators.iter()) { total_evaluations.extend_from_slice(nums.evaluations()); total_evaluations.extend_from_slice(dens.evaluations()); @@ -72,9 +70,7 @@ pub fn verify_gkr< // reduce the claim let mut reduced_claims = vec![]; - for (_circuit_id, (nums, dens)) in - numerators.into_iter().zip(denominators.into_iter()).enumerate() - { + for (nums, dens) in numerators.iter().zip(denominators.iter()) { let p0 = nums.evaluations()[0]; let p1 = nums.evaluations()[1]; let q0 = dens.evaluations()[0]; @@ -113,19 +109,21 @@ pub fn verify_gkr< )?; // generate the random challenge to reduce two claims into a single claim - for tmp in openings.iter() { - transcript.reseed(H::hash_elements(&tmp)); + let mut total_openings = Vec::with_capacity(openings.len() * 4); + for opening_circuit_i in openings.iter() { + total_openings.extend_from_slice(&opening_circuit_i); } + transcript.reseed(H::hash_elements(&total_openings)); let r_layer = transcript.draw().map_err(|_| VerifierError::FailedToGenerateChallenge)?; - for (j, ops) in openings.iter().enumerate() { + for (circuit_id, ops) in openings.iter().enumerate() { let p0 = ops[0]; let p1 = ops[1]; let q0 = ops[2]; let q1 = ops[3]; let reduced_claim = (p0 + r_layer * (p1 - p0), q0 + r_layer * (q1 - q0)); - reduced_claims[j] = reduced_claim; + reduced_claims[circuit_id] = reduced_claim; } // collect the randomness used for the current layer diff --git a/winterfell/src/tests/logup_gkr_simple.rs b/winterfell/src/tests/logup_gkr_simple.rs index cfa2743dc..32e1e2ec2 100644 --- a/winterfell/src/tests/logup_gkr_simple.rs +++ b/winterfell/src/tests/logup_gkr_simple.rs @@ -23,8 +23,7 @@ use crate::{ #[test] fn test_logup_gkr() { let aux_trace_width = 1; - let trace = LogUpGkrSimple::new(2_usize.pow(13 - ), aux_trace_width); + let trace = LogUpGkrSimple::new(2_usize.pow(13), aux_trace_width); let prover = LogUpGkrSimpleProver::new(aux_trace_width); let proof = prover.prove(trace).unwrap(); @@ -195,14 +194,7 @@ impl PlainLogUpGkrEval { let committed_3 = LogUpGkrOracle::CurrentRow(3); let committed_4 = LogUpGkrOracle::CurrentRow(4); - let oracles = vec![ - committed_0, - committed_1, - committed_2, - committed_3, - committed_4, - - ]; + let oracles = vec![committed_0, committed_1, committed_2, committed_3, committed_4]; Self { oracles, _field: PhantomData } } } @@ -238,7 +230,6 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { query[2] = frame.current()[2]; query[3] = frame.current()[3]; query[4] = frame.current()[4]; - } fn evaluate_query( @@ -264,8 +255,6 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { denominator[1] = -(rand_values[0] - E::from(query[2])); denominator[2] = -(rand_values[0] - E::from(query[3])); denominator[3] = -(rand_values[0] - E::from(query[4])); - - //numerator[0] = -E::ONE; //numerator[1] = E::ONE; From 79f5b2908225505c4e31327f87cf9c551b58b88e Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Thu, 26 Sep 2024 10:41:19 +0200 Subject: [PATCH 23/44] wip: optimize input layer generation --- prover/src/logup_gkr/mod.rs | 79 +++++++++++++++++++++++++++++++++++-- 1 file changed, 75 insertions(+), 4 deletions(-) diff --git a/prover/src/logup_gkr/mod.rs b/prover/src/logup_gkr/mod.rs index 5b4e8b5c5..77b623a41 100644 --- a/prover/src/logup_gkr/mod.rs +++ b/prover/src/logup_gkr/mod.rs @@ -4,7 +4,7 @@ use air::{EvaluationFrame, GkrData, LogUpGkrEvaluator}; use math::FieldElement; use sumcheck::{CircuitLayer, CircuitLayerPolys, CircuitWire, EqFunction, SumCheckProverError}; use tracing::instrument; -use utils::{chunks, uninit_vector}; +use utils::{batch_iter_mut, chunks, uninit_vector}; use crate::Trace; @@ -13,7 +13,7 @@ pub use prover::prove_gkr; #[cfg(feature = "concurrent")] pub use utils::{ rayon::{current_num_threads as rayon_num_threads, prelude::*}, - {batch_iter_mut, chunks_mut, iter, iter_mut}, + {chunks_mut, iter, iter_mut}, }; // EVALUATED CIRCUIT @@ -99,7 +99,7 @@ impl EvaluatedCircuit { /// Evaluates the output layer at `query`, where the numerators of the output layer are treated /// as evaluations of a multilinear polynomial, and similarly for the denominators. pub fn evaluate_output_layer(&self, query: E) -> Vec<(E, E)> { - let mut res = vec![]; + let mut res = Vec::with_capacity(self.output_layers().len()); for output_layer in self.output_layers().iter() { let CircuitLayerPolys { numerators, denominators } = output_layer; @@ -113,7 +113,7 @@ impl EvaluatedCircuit { /// Generates the input layer of the circuit from the main trace columns and some randomness /// provided by the verifier. - fn generate_input_layer( + fn generate_input_layer_old( trace: &impl Trace, evaluator: &impl LogUpGkrEvaluator, log_up_randomness: &[E], @@ -156,6 +156,77 @@ impl EvaluatedCircuit { .collect() } + fn generate_input_layer( + trace: &impl Trace, + evaluator: &impl LogUpGkrEvaluator, + log_up_randomness: &[E], + ) -> Vec> { + let num_fractions = evaluator.get_num_fractions(); + let periodic_values = evaluator.build_periodic_values(trace.main_segment().num_rows()); + + let mut input_layer_wires = + unsafe { uninit_vector(trace.main_segment().num_rows() * num_fractions) }; + let num_cols = trace.main_segment().num_cols(); + let num_oracles = evaluator.get_oracles().len(); + let num_periodic_cols = periodic_values.num_columns(); + + batch_iter_mut!( + &mut input_layer_wires, + 1024, + |batch: &mut [CircuitWire], batch_offset: usize| { + let mut main_frame = EvaluationFrame::new(num_cols); + let mut query = vec![E::BaseField::ZERO; num_oracles]; + let mut periodic_values_row = vec![E::BaseField::ZERO; num_periodic_cols]; + let mut numerators = vec![E::ZERO; num_fractions]; + let mut denominators = vec![E::ZERO; num_fractions]; + + let row_offset = batch_offset / num_fractions; + let batch_size = batch.len(); + let num_rows_per_batch = batch_size / num_fractions; + + for i in + (0..trace.main_segment().num_rows()).skip(row_offset).take(num_rows_per_batch) + { + trace.read_main_frame(i, &mut main_frame); + periodic_values.fill_periodic_values_at(i, &mut periodic_values_row); + evaluator.build_query(&main_frame, &mut query); + + evaluator.evaluate_query( + &query, + &periodic_values_row, + log_up_randomness, + &mut numerators, + &mut denominators, + ); + + let n = (i - row_offset) * num_fractions; + for ((wire, numerator), denominator) in batch[n..n + num_fractions] + .iter_mut() + .zip(numerators.iter()) + .zip(denominators.iter()) + { + *wire = CircuitWire::new(*numerator, *denominator); + } + } + } + ); + + let mut result: Vec>> = + vec![unsafe { uninit_vector(trace.main_segment().num_rows()) }; num_fractions]; + + input_layer_wires.chunks(num_fractions).enumerate().for_each(|(row, chunk)| { + chunk + .iter() + .zip(result.iter_mut()) + .for_each(|(value, destination)| destination[row] = *value); + }); + + result + .iter() + .map(|input_layer| CircuitLayer::new(input_layer.to_vec())) + .collect() + } + /// Computes the subsequent layer of the circuit from a given layer. fn compute_next_layer(prev_layers: &[CircuitLayer]) -> Vec> { let mut next_layers = Vec::with_capacity(prev_layers.len() / 2); From e6ade309c88cf27c038e71a47aadc09090a498ef Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Wed, 25 Sep 2024 17:52:15 +0200 Subject: [PATCH 24/44] added example for logup-gkr --- examples/src/lib.rs | 8 +- examples/src/logup_gkr/air.rs | 189 +++++++++++++++++++++++++++++++ examples/src/logup_gkr/mod.rs | 124 ++++++++++++++++++++ examples/src/logup_gkr/prover.rs | 169 +++++++++++++++++++++++++++ examples/src/logup_gkr/tests.rs | 35 ++++++ examples/src/main.rs | 3 +- winterfell/src/lib.rs | 2 +- 7 files changed, 527 insertions(+), 3 deletions(-) create mode 100644 examples/src/logup_gkr/air.rs create mode 100644 examples/src/logup_gkr/mod.rs create mode 100644 examples/src/logup_gkr/prover.rs create mode 100644 examples/src/logup_gkr/tests.rs diff --git a/examples/src/lib.rs b/examples/src/lib.rs index 33f733d7c..5bfa3e217 100644 --- a/examples/src/lib.rs +++ b/examples/src/lib.rs @@ -9,7 +9,7 @@ use winterfell::{ math::fields::f128::BaseElement, FieldExtension, Proof, ProofOptions, VerifierError, }; - +pub mod logup_gkr; pub mod fibonacci; #[cfg(feature = "std")] pub mod lamport; @@ -198,6 +198,12 @@ pub enum ExampleType { #[structopt(short = "n", default_value = "3")] num_signers: usize, }, + /// LogUp-GKR + LogUpGkr { + /// Length of the trace; must be a power of two + #[structopt(short = "n", default_value = "65536")] + trace_length: usize, + }, } /// Defines a set of hash functions available for the provided examples. Some examples may not diff --git a/examples/src/logup_gkr/air.rs b/examples/src/logup_gkr/air.rs new file mode 100644 index 000000000..9a1a99b5c --- /dev/null +++ b/examples/src/logup_gkr/air.rs @@ -0,0 +1,189 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use std::marker::PhantomData; + +use winterfell::{ + math::{fields::f64::BaseElement, ExtensionOf, FieldElement, StarkField}, + Air, AirContext, Assertion, AuxRandElements, EvaluationFrame, LogUpGkrEvaluator, + LogUpGkrOracle, TraceInfo, TransitionConstraintDegree, +}; + +use super::ProofOptions; + +pub(crate) struct LogUpGkrSimpleAir { + context: AirContext, +} + +impl Air for LogUpGkrSimpleAir { + type BaseField = BaseElement; + type PublicInputs = (); + + fn new(trace_info: TraceInfo, _pub_inputs: Self::PublicInputs, options: ProofOptions) -> Self { + Self { + context: AirContext::new_multi_segment( + trace_info, + _pub_inputs, + vec![TransitionConstraintDegree::new(1)], + vec![], + 1, + 0, + options, + ), + } + } + + fn context(&self) -> &AirContext { + &self.context + } + + fn evaluate_transition>( + &self, + frame: &EvaluationFrame, + _periodic_values: &[E], + result: &mut [E], + ) { + let current = frame.current()[0]; + let next = frame.next()[0]; + + // increments by 1 + result[0] = next - current - E::ONE; + } + + fn get_assertions(&self) -> Vec> { + vec![Assertion::single(0, 0, BaseElement::ZERO)] + } + + fn evaluate_aux_transition( + &self, + _main_frame: &EvaluationFrame, + _aux_frame: &EvaluationFrame, + _periodic_values: &[F], + _aux_rand_elements: &AuxRandElements, + _result: &mut [E], + ) where + F: FieldElement, + E: FieldElement + ExtensionOf, + { + // do nothing + } + + fn get_aux_assertions>( + &self, + _aux_rand_elements: &AuxRandElements, + ) -> Vec> { + vec![] + } + + fn get_logup_gkr_evaluator( + &self, + ) -> impl LogUpGkrEvaluator + { + PlainLogUpGkrEval::new() + } +} + +#[derive(Clone, Default)] +pub struct PlainLogUpGkrEval { + oracles: Vec, + _field: PhantomData, +} + +impl PlainLogUpGkrEval { + pub fn new() -> Self { + let committed_0 = LogUpGkrOracle::CurrentRow(0); + let committed_1 = LogUpGkrOracle::CurrentRow(1); + let committed_2 = LogUpGkrOracle::CurrentRow(2); + let committed_3 = LogUpGkrOracle::CurrentRow(3); + let committed_4 = LogUpGkrOracle::CurrentRow(4); + let oracles = vec![committed_0, committed_1, committed_2, committed_3, committed_4]; + Self { oracles, _field: PhantomData } + } +} + +impl LogUpGkrEvaluator for PlainLogUpGkrEval { + type BaseField = BaseElement; + + type PublicInputs = (); + + fn get_oracles(&self) -> &[LogUpGkrOracle] { + &self.oracles + } + + fn get_num_rand_values(&self) -> usize { + 1 + } + + fn get_num_fractions(&self) -> usize { + 16 + } + + fn max_degree(&self) -> usize { + 10 + } + + fn build_query(&self, frame: &EvaluationFrame, query: &mut [E]) + where + E: FieldElement, + { + query.iter_mut().zip(frame.current().iter()).for_each(|(q, f)| *q = *f) + } + + fn evaluate_query( + &self, + query: &[F], + _periodic_values: &[F], + rand_values: &[E], + numerator: &mut [E], + denominator: &mut [E], + ) where + F: FieldElement, + E: FieldElement + ExtensionOf, + { + assert_eq!(numerator.len(), 16); + assert_eq!(denominator.len(), 16); + assert_eq!(query.len(), 5); + numerator[0] = E::from(query[1]); + numerator[1] = E::ONE; + numerator[2] = E::ONE; + numerator[3] = E::ONE; + numerator[4] = E::from(query[1]); + numerator[5] = E::ONE; + numerator[6] = E::ONE; + numerator[7] = E::ONE; + numerator[8] = E::from(query[1]); + numerator[9] = E::ONE; + numerator[10] = E::ONE; + numerator[11] = E::ONE; + numerator[12] = E::from(query[1]); + numerator[13] = E::ONE; + numerator[14] = E::ONE; + numerator[15] = E::ONE; + + denominator[0] = rand_values[0] - E::from(query[0]); + denominator[1] = -(rand_values[0] - E::from(query[2])); + denominator[2] = -(rand_values[0] - E::from(query[3])); + denominator[3] = -(rand_values[0] - E::from(query[4])); + denominator[4] = rand_values[0] - E::from(query[0]); + denominator[5] = -(rand_values[0] - E::from(query[2])); + denominator[6] = -(rand_values[0] - E::from(query[3])); + denominator[7] = -(rand_values[0] - E::from(query[4])); + denominator[8] = rand_values[0] - E::from(query[0]); + denominator[9] = -(rand_values[0] - E::from(query[2])); + denominator[10] = -(rand_values[0] - E::from(query[3])); + denominator[11] = -(rand_values[0] - E::from(query[4])); + denominator[12] = rand_values[0] - E::from(query[0]); + denominator[13] = -(rand_values[0] - E::from(query[2])); + denominator[14] = -(rand_values[0] - E::from(query[3])); + denominator[15] = -(rand_values[0] - E::from(query[4])); + } + + fn compute_claim(&self, _inputs: &Self::PublicInputs, _rand_values: &[E]) -> E + where + E: FieldElement, + { + E::ZERO + } +} diff --git a/examples/src/logup_gkr/mod.rs b/examples/src/logup_gkr/mod.rs new file mode 100644 index 000000000..f16054e8f --- /dev/null +++ b/examples/src/logup_gkr/mod.rs @@ -0,0 +1,124 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use core::marker::PhantomData; + +use winterfell::{ + crypto::{DefaultRandomCoin, ElementHasher, MerkleTree}, + math::fields::f64::BaseElement, + Proof, ProofOptions, Prover, VerifierError, +}; + +use crate::{Example, ExampleOptions, HashFunction}; + +mod air; +use air::LogUpGkrSimpleAir; + +mod prover; +use prover::LogUpGkrSimpleProver; + +#[cfg(test)] +mod tests; + +// CONSTANTS AND TYPES +// ================================================================================================ + +const AUX_TRACE_WIDTH: usize = 2; + +type Blake3_192 = winterfell::crypto::hashers::Blake3_192; +type Blake3_256 = winterfell::crypto::hashers::Blake3_256; +type Sha3_256 = winterfell::crypto::hashers::Sha3_256; +type Rp64_256 = winterfell::crypto::hashers::Rp64_256; +type RpJive64_256 = winterfell::crypto::hashers::RpJive64_256; + +// FIBONACCI EXAMPLE +// ================================================================================================ + +pub fn get_example( + options: &ExampleOptions, + trace_length: usize, +) -> Result, String> { + let (options, hash_fn) = options.to_proof_options(28, 8); + + match hash_fn { + HashFunction::Blake3_192 => { + Ok(Box::new(LogUpGkrSimple::::new(trace_length, AUX_TRACE_WIDTH, options))) + }, + HashFunction::Blake3_256 => { + Ok(Box::new(LogUpGkrSimple::::new(trace_length, AUX_TRACE_WIDTH, options))) + }, + HashFunction::Sha3_256 => { + Ok(Box::new(LogUpGkrSimple::::new(trace_length, AUX_TRACE_WIDTH, options))) + }, + HashFunction::Rp64_256 => { + Ok(Box::new(LogUpGkrSimple::::new(trace_length, AUX_TRACE_WIDTH, options))) + }, + HashFunction::RpJive64_256 => Ok(Box::new(LogUpGkrSimple::::new( + trace_length, + AUX_TRACE_WIDTH, + options, + ))), + } +} + +#[derive(Clone, Debug)] +struct LogUpGkrSimple> { + trace_len: usize, + aux_segment_width: usize, + options: ProofOptions, + _hasher: PhantomData, +} + +impl> LogUpGkrSimple { + fn new(trace_len: usize, aux_segment_width: usize, options: ProofOptions) -> Self { + assert!(trace_len < u32::MAX.try_into().unwrap()); + + Self { + trace_len, + aux_segment_width, + options, + _hasher: PhantomData, + } + } +} + +// EXAMPLE IMPLEMENTATION +// ================================================================================================ + +impl Example for LogUpGkrSimple +where + H: ElementHasher + Sync + Send, +{ + fn prove(&self) -> Proof { + // create a prover + let prover = LogUpGkrSimpleProver::::new(AUX_TRACE_WIDTH, self.options.clone()); + + let trace = prover.build_trace(self.trace_len, self.aux_segment_width); + + // generate the proof + prover.prove(trace).unwrap() + } + + fn verify(&self, proof: Proof) -> Result<(), VerifierError> { + let acceptable_options = + winterfell::AcceptableOptions::OptionSet(vec![proof.options().clone()]); + + winterfell::verify::, MerkleTree>( + proof, + (), + &acceptable_options, + ) + } + + fn verify_with_wrong_inputs(&self, proof: Proof) -> Result<(), VerifierError> { + let acceptable_options = + winterfell::AcceptableOptions::OptionSet(vec![proof.options().clone()]); + winterfell::verify::, MerkleTree>( + proof, + (), + &acceptable_options, + ) + } +} diff --git a/examples/src/logup_gkr/prover.rs b/examples/src/logup_gkr/prover.rs new file mode 100644 index 000000000..5a4affe31 --- /dev/null +++ b/examples/src/logup_gkr/prover.rs @@ -0,0 +1,169 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use winterfell::{ + crypto::MerkleTree, math::FieldElement, matrix::ColMatrix, Air, AuxRandElements, + ConstraintCompositionCoefficients, DefaultTraceLde, EvaluationFrame, + LogUpGkrConstraintEvaluator, StarkDomain, Trace, TraceInfo, TracePolyTable, +}; + +use super::{ + air::LogUpGkrSimpleAir, BaseElement, DefaultRandomCoin, ElementHasher, PhantomData, + ProofOptions, Prover, +}; + +pub(crate) struct LogUpGkrSimpleProver + Sync + Send> { + aux_trace_width: usize, + options: ProofOptions, + _hasher: PhantomData, +} + +impl + Sync + Send> LogUpGkrSimpleProver { + pub(crate) fn new(aux_trace_width: usize, options: ProofOptions) -> Self { + Self { + aux_trace_width, + options, + _hasher: PhantomData, + } + } + + /// Builds an execution trace for computing a Fibonacci sequence of the specified length such + /// that each row advances the sequence by 2 terms. + pub fn build_trace(&self, trace_len: usize, aux_segment_width: usize) -> LogUpGkrSimpleTrace { + LogUpGkrSimpleTrace::new(trace_len, aux_segment_width) + } +} + +impl + Sync + Send> Prover for LogUpGkrSimpleProver { + type BaseField = BaseElement; + type Air = LogUpGkrSimpleAir; + type Trace = LogUpGkrSimpleTrace; + type HashFn = H; + type VC = MerkleTree; + type RandomCoin = DefaultRandomCoin; + type TraceLde> = + DefaultTraceLde; + type ConstraintEvaluator<'a, E: FieldElement> = + LogUpGkrConstraintEvaluator<'a, LogUpGkrSimpleAir, E>; + + fn get_pub_inputs(&self, _trace: &Self::Trace) -> <::Air as Air>::PublicInputs { + } + + fn options(&self) -> &ProofOptions { + &self.options + } + + fn new_trace_lde( + &self, + trace_info: &TraceInfo, + main_trace: &ColMatrix, + domain: &StarkDomain, + ) -> (Self::TraceLde, TracePolyTable) + where + E: FieldElement, + { + DefaultTraceLde::new(trace_info, main_trace, domain) + } + + fn new_evaluator<'a, E>( + &self, + air: &'a Self::Air, + aux_rand_elements: Option>, + composition_coefficients: ConstraintCompositionCoefficients, + ) -> Self::ConstraintEvaluator<'a, E> + where + E: FieldElement, + { + LogUpGkrConstraintEvaluator::new(air, aux_rand_elements.unwrap(), composition_coefficients) + } + + fn build_aux_trace(&self, main_trace: &Self::Trace, _aux_rand_elements: &[E]) -> ColMatrix + where + E: FieldElement, + { + let main_trace = main_trace.main_segment(); + + let mut columns = Vec::new(); + + let rand_summed = E::from(777_u32); + for _ in 0..self.aux_trace_width { + // building a dummy auxiliary column + let column = main_trace + .get_column(0) + .iter() + .map(|row_val| rand_summed.mul_base(*row_val)) + .collect(); + + columns.push(column); + } + + ColMatrix::new(columns) + } +} + +#[derive(Clone, Debug)] +pub(crate) struct LogUpGkrSimpleTrace { + // dummy main trace + main_trace: ColMatrix, + info: TraceInfo, +} + +impl LogUpGkrSimpleTrace { + fn new(trace_len: usize, aux_segment_width: usize) -> Self { + assert!(trace_len < u32::MAX.try_into().unwrap()); + + let table: Vec = + (0..trace_len).map(|idx| BaseElement::from(idx as u32)).collect(); + let mut multiplicity: Vec = + (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + multiplicity[0] = BaseElement::new(3 * trace_len as u64 - 3 * 4); + multiplicity[1] = BaseElement::new(3 * 4); + + let mut values_0: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + + for i in 0..4 { + values_0[i + 4] = BaseElement::ONE; + } + + let mut values_1: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + + for i in 0..4 { + values_1[i + 4] = BaseElement::ONE; + } + + let mut values_2: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + + for i in 0..4 { + values_2[i + 4] = BaseElement::ONE; + } + + Self { + main_trace: ColMatrix::new(vec![table, multiplicity, values_0, values_1, values_2]), + info: TraceInfo::new_multi_segment(5, aux_segment_width, 1, trace_len, vec![], true), + } + } + + fn len(&self) -> usize { + self.main_trace.num_rows() + } +} + +impl Trace for LogUpGkrSimpleTrace { + type BaseField = BaseElement; + + fn info(&self) -> &TraceInfo { + &self.info + } + + fn main_segment(&self) -> &ColMatrix { + &self.main_trace + } + + fn read_main_frame(&self, row_idx: usize, frame: &mut EvaluationFrame) { + let next_row_idx = row_idx + 1; + self.main_trace.read_row_into(row_idx, frame.current_mut()); + self.main_trace.read_row_into(next_row_idx % self.len(), frame.next_mut()); + } +} diff --git a/examples/src/logup_gkr/tests.rs b/examples/src/logup_gkr/tests.rs new file mode 100644 index 000000000..cff9c9577 --- /dev/null +++ b/examples/src/logup_gkr/tests.rs @@ -0,0 +1,35 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use winterfell::{FieldExtension, ProofOptions}; + +use super::{Rp64_256, AUX_TRACE_WIDTH}; + +#[test] +fn logup_gkr_small_test_basic_proof_verification() { + let logup_gkr = Box::new(super::LogUpGkrSimple::::new(128, AUX_TRACE_WIDTH, build_options(false))); + crate::tests::test_basic_proof_verification(logup_gkr); +} + +#[test] +fn logup_gkr_small_test_basic_proof_verification_extension() { + let logup_gkr = Box::new(super::LogUpGkrSimple::::new(128, AUX_TRACE_WIDTH, build_options(true))); + crate::tests::test_basic_proof_verification(logup_gkr); +} + +#[test] +fn logup_gkr_small_test_basic_proof_verification_fail() { + let logup_gkr = Box::new(super::LogUpGkrSimple::::new(128, AUX_TRACE_WIDTH, build_options(false))); + crate::tests::test_basic_proof_verification_fail(logup_gkr); +} + +fn build_options(use_extension_field: bool) -> ProofOptions { + let extension = if use_extension_field { + FieldExtension::Quadratic + } else { + FieldExtension::None + }; + ProofOptions::new(28, 8, 0, extension, 4, 31) +} diff --git a/examples/src/main.rs b/examples/src/main.rs index 36c1e0c0d..79a7a0e5f 100644 --- a/examples/src/main.rs +++ b/examples/src/main.rs @@ -5,7 +5,7 @@ use std::time::Instant; -use examples::{fibonacci, rescue, vdf, ExampleOptions, ExampleType}; +use examples::{fibonacci, logup_gkr, rescue, vdf, ExampleOptions, ExampleType}; #[cfg(feature = "std")] use examples::{lamport, merkle, rescue_raps}; use structopt::StructOpt; @@ -82,6 +82,7 @@ fn main() { ExampleType::LamportT { num_signers } => { lamport::threshold::get_example(&options, num_signers) }, + ExampleType::LogUpGkr { trace_length } => logup_gkr::get_example(&options, trace_length), } .expect("The example failed to initialize."); diff --git a/winterfell/src/lib.rs b/winterfell/src/lib.rs index 621796864..291e0f27e 100644 --- a/winterfell/src/lib.rs +++ b/winterfell/src/lib.rs @@ -590,7 +590,7 @@ #[cfg(test)] extern crate std; -pub use air::{AuxRandElements, LogUpGkrEvaluator}; +pub use air::{AuxRandElements, LogUpGkrEvaluator, LogUpGkrOracle}; pub use prover::{ crypto, iterators, math, matrix, Air, AirContext, Assertion, AuxTraceWithMetadata, BoundaryConstraint, BoundaryConstraintGroup, CompositionPolyTrace, From 9efa494f006fc7d165fb4ddb61777ba3bb1318d3 Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Thu, 26 Sep 2024 17:35:22 +0200 Subject: [PATCH 25/44] increase logup-gkr example --- examples/src/logup_gkr/air.rs | 40 +++++++++++++++++++++++++++++++---- 1 file changed, 36 insertions(+), 4 deletions(-) diff --git a/examples/src/logup_gkr/air.rs b/examples/src/logup_gkr/air.rs index 9a1a99b5c..cea162be3 100644 --- a/examples/src/logup_gkr/air.rs +++ b/examples/src/logup_gkr/air.rs @@ -116,8 +116,8 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { 1 } - fn get_num_fractions(&self) -> usize { - 16 + fn get_num_fractions(&self) -> usize { + 32 } fn max_degree(&self) -> usize { @@ -142,8 +142,8 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { F: FieldElement, E: FieldElement + ExtensionOf, { - assert_eq!(numerator.len(), 16); - assert_eq!(denominator.len(), 16); + assert_eq!(numerator.len(), 32); + assert_eq!(denominator.len(), 32); assert_eq!(query.len(), 5); numerator[0] = E::from(query[1]); numerator[1] = E::ONE; @@ -161,6 +161,22 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { numerator[13] = E::ONE; numerator[14] = E::ONE; numerator[15] = E::ONE; + numerator[16] = E::from(query[1]); + numerator[17] = E::ONE; + numerator[18] = E::ONE; + numerator[19] = E::ONE; + numerator[20] = E::from(query[1]); + numerator[21] = E::ONE; + numerator[22] = E::ONE; + numerator[23] = E::ONE; + numerator[24] = E::from(query[1]); + numerator[25] = E::ONE; + numerator[26] = E::ONE; + numerator[27] = E::ONE; + numerator[28] = E::from(query[1]); + numerator[29] = E::ONE; + numerator[30] = E::ONE; + numerator[31] = E::ONE; denominator[0] = rand_values[0] - E::from(query[0]); denominator[1] = -(rand_values[0] - E::from(query[2])); @@ -178,6 +194,22 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { denominator[13] = -(rand_values[0] - E::from(query[2])); denominator[14] = -(rand_values[0] - E::from(query[3])); denominator[15] = -(rand_values[0] - E::from(query[4])); + denominator[16] = rand_values[0] - E::from(query[0]); + denominator[17] = -(rand_values[0] - E::from(query[2])); + denominator[18] = -(rand_values[0] - E::from(query[3])); + denominator[19] = -(rand_values[0] - E::from(query[4])); + denominator[20] = rand_values[0] - E::from(query[0]); + denominator[21] = -(rand_values[0] - E::from(query[2])); + denominator[22] = -(rand_values[0] - E::from(query[3])); + denominator[23] = -(rand_values[0] - E::from(query[4])); + denominator[24] = rand_values[0] - E::from(query[0]); + denominator[25] = -(rand_values[0] - E::from(query[2])); + denominator[26] = -(rand_values[0] - E::from(query[3])); + denominator[27] = -(rand_values[0] - E::from(query[4])); + denominator[28] = rand_values[0] - E::from(query[0]); + denominator[29] = -(rand_values[0] - E::from(query[2])); + denominator[30] = -(rand_values[0] - E::from(query[3])); + denominator[31] = -(rand_values[0] - E::from(query[4])); } fn compute_claim(&self, _inputs: &Self::PublicInputs, _rand_values: &[E]) -> E From d4da930ec02a47b39286b5ac1d7a0cc869c52b2a Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Thu, 26 Sep 2024 18:45:23 +0200 Subject: [PATCH 26/44] fix: clippy --- prover/src/logup_gkr/mod.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/prover/src/logup_gkr/mod.rs b/prover/src/logup_gkr/mod.rs index 77b623a41..5fd834323 100644 --- a/prover/src/logup_gkr/mod.rs +++ b/prover/src/logup_gkr/mod.rs @@ -113,6 +113,7 @@ impl EvaluatedCircuit { /// Generates the input layer of the circuit from the main trace columns and some randomness /// provided by the verifier. + #[allow(dead_code)] fn generate_input_layer_old( trace: &impl Trace, evaluator: &impl LogUpGkrEvaluator, From 7c1a4c1cc3741e2b31d74b04069d237260736279 Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Mon, 30 Sep 2024 15:18:36 +0200 Subject: [PATCH 27/44] wip --- prover/src/logup_gkr/mod.rs | 72 +++++++++++++++++++++--------- sumcheck/src/prover/high_degree.rs | 2 + sumcheck/src/prover/mod.rs | 4 +- 3 files changed, 55 insertions(+), 23 deletions(-) diff --git a/prover/src/logup_gkr/mod.rs b/prover/src/logup_gkr/mod.rs index 5fd834323..060a816a4 100644 --- a/prover/src/logup_gkr/mod.rs +++ b/prover/src/logup_gkr/mod.rs @@ -1,4 +1,5 @@ use alloc::vec::Vec; +use core::num; use air::{EvaluationFrame, GkrData, LogUpGkrEvaluator}; use math::FieldElement; @@ -69,9 +70,12 @@ impl EvaluatedCircuit { ) -> Result { let mut layer_polys = Vec::new(); - let mut current_layer = + let current_layer = Self::generate_input_layer(main_trace_columns, evaluator, log_up_randomness); - while current_layer[0].num_wires() > 1 { + + let mut current_layer = + Self::generate_second_layer(current_layer, evaluator.get_num_fractions()); + while current_layer[0].len() > 1 { let next_layer = Self::compute_next_layer(¤t_layer); layer_polys.push(CircuitLayerPolys::from_circuit_layer(¤t_layer)); @@ -161,7 +165,7 @@ impl EvaluatedCircuit { trace: &impl Trace, evaluator: &impl LogUpGkrEvaluator, log_up_randomness: &[E], - ) -> Vec> { + ) -> Vec> { let num_fractions = evaluator.get_num_fractions(); let periodic_values = evaluator.build_periodic_values(trace.main_segment().num_rows()); @@ -211,28 +215,14 @@ impl EvaluatedCircuit { } } ); - - let mut result: Vec>> = - vec![unsafe { uninit_vector(trace.main_segment().num_rows()) }; num_fractions]; - - input_layer_wires.chunks(num_fractions).enumerate().for_each(|(row, chunk)| { - chunk - .iter() - .zip(result.iter_mut()) - .for_each(|(value, destination)| destination[row] = *value); - }); - - result - .iter() - .map(|input_layer| CircuitLayer::new(input_layer.to_vec())) - .collect() + input_layer_wires } /// Computes the subsequent layer of the circuit from a given layer. - fn compute_next_layer(prev_layers: &[CircuitLayer]) -> Vec> { + fn compute_next_layer(prev_layers: &[Vec>]) -> Vec>> { let mut next_layers = Vec::with_capacity(prev_layers.len() / 2); for prev_layer in prev_layers.iter() { - let next_layer_wires = chunks!(prev_layer.wires(), 2) + let next_layer_wires = chunks!(prev_layer, 2) .map(|input_wires| { let left_input_wire = input_wires[0]; let right_input_wire = input_wires[1]; @@ -242,10 +232,50 @@ impl EvaluatedCircuit { }) .collect(); - next_layers.push(CircuitLayer::new(next_layer_wires)) + next_layers.push((next_layer_wires)) } next_layers } + + fn generate_second_layer( + current_layer: Vec>, + num_fractions: usize, + ) -> Vec>> { + let mut result: Vec>> = + vec![ + unsafe { uninit_vector(current_layer.len() / (num_fractions * 2)) }; + num_fractions + ]; + + result.par_iter_mut().enumerate().for_each(|(circuit_idx, circuit)| { + current_layer.chunks(2 * num_fractions).enumerate().for_each( + |(row, fractions_at_row)| { + let left = fractions_at_row[circuit_idx]; + let right = fractions_at_row[circuit_idx + num_fractions]; + circuit[row] = left + right; + }, + ); + }); + //input_layer_wires.chunks(num_fractions).enumerate().for_each(|(row, chunk)| { + //chunk + //.iter() + //.zip(result.iter_mut()) + //.for_each(|(value, destination)| destination[row] = *value); + //}); + + //result + //.iter() + //.map(|input_layer| CircuitLayer::new(input_layer.to_vec())) + //.collect() + + //current_layer.par_chunks(2 * num_fractions).for_each(|chunk| { + //let (even, odd) = chunk.split_at(num_fractions); + //let res = even.iter().zip(odd.iter()).map(|(&e, &o)| e + o).collect(); + //}); + + //todo!() + result + } } /// Represents a claim to be proven by a subsequent call to the sum-check protocol. diff --git a/sumcheck/src/prover/high_degree.rs b/sumcheck/src/prover/high_degree.rs index 7b525b617..9c29b5c42 100644 --- a/sumcheck/src/prover/high_degree.rs +++ b/sumcheck/src/prover/high_degree.rs @@ -301,6 +301,8 @@ pub fn sum_check_prove_higher_degree< let round_challenge = coin.draw().map_err(|_| SumCheckProverError::FailedToGenerateChallenge)?; + assert_eq!(evaluation_point.len(), 11); + assert_eq!(mls[0].num_variables(), 13); // update `scaling_up_factor` alpha_i = evaluation_point[evaluation_point.len() + 1 - mls[0].num_variables()]; scaling_up_factor *= diff --git a/sumcheck/src/prover/mod.rs b/sumcheck/src/prover/mod.rs index 135af4620..98f18b589 100644 --- a/sumcheck/src/prover/mod.rs +++ b/sumcheck/src/prover/mod.rs @@ -62,10 +62,10 @@ impl CircuitLayerPolys where E: FieldElement, { - pub fn from_circuit_layer(layers: &[CircuitLayer]) -> Vec { + pub fn from_circuit_layer(layers: &[Vec>]) -> Vec { let mut result = vec![]; for layer in layers { - result.push(Self::from_wires(layer.wires.clone())) + result.push(Self::from_wires(layer.clone())) } result } From fce45c87de1aa7db9e25f18d5b59bdb52cb14ef7 Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Mon, 30 Sep 2024 15:33:02 +0200 Subject: [PATCH 28/44] wip --- prover/src/logup_gkr/prover.rs | 2 +- sumcheck/src/prover/high_degree.rs | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/prover/src/logup_gkr/prover.rs b/prover/src/logup_gkr/prover.rs index d1a91bf67..5102dc75e 100644 --- a/prover/src/logup_gkr/prover.rs +++ b/prover/src/logup_gkr/prover.rs @@ -250,7 +250,7 @@ fn prove_intermediate_layers< // loop over all inner layers in order to iteratively reduce a layer in terms of its successor // layer. Note that we don't include the input layer, since its predecessor layer will be // reduced in terms of the input layer separately in `prove_final_circuit_layer`. - for inner_layer in circuit.layers().into_iter().skip(1).rev().skip(1) { + for inner_layer in circuit.layers().into_iter().skip(0).rev().skip(1) { // construct the Lagrange kernel evaluated at the previous GKR round randomness let mut eq_mle = EqFunction::ml_at(evaluation_point.clone().into()); diff --git a/sumcheck/src/prover/high_degree.rs b/sumcheck/src/prover/high_degree.rs index 9c29b5c42..a0e8efe84 100644 --- a/sumcheck/src/prover/high_degree.rs +++ b/sumcheck/src/prover/high_degree.rs @@ -301,8 +301,8 @@ pub fn sum_check_prove_higher_degree< let round_challenge = coin.draw().map_err(|_| SumCheckProverError::FailedToGenerateChallenge)?; - assert_eq!(evaluation_point.len(), 11); - assert_eq!(mls[0].num_variables(), 13); + //assert_eq!(evaluation_point.len(), 12); + //assert_eq!(mls[0].num_variables(), 13); // update `scaling_up_factor` alpha_i = evaluation_point[evaluation_point.len() + 1 - mls[0].num_variables()]; scaling_up_factor *= From c2bdd11ad3f3528eff0e67b2cb16984873b0d461 Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Mon, 30 Sep 2024 16:28:34 +0200 Subject: [PATCH 29/44] reduce size of example --- examples/src/logup_gkr/air.rs | 70 +++++++++++++++++------------------ 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/examples/src/logup_gkr/air.rs b/examples/src/logup_gkr/air.rs index cea162be3..ff733b124 100644 --- a/examples/src/logup_gkr/air.rs +++ b/examples/src/logup_gkr/air.rs @@ -117,7 +117,7 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { } fn get_num_fractions(&self) -> usize { - 32 + 16 } fn max_degree(&self) -> usize { @@ -142,8 +142,8 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { F: FieldElement, E: FieldElement + ExtensionOf, { - assert_eq!(numerator.len(), 32); - assert_eq!(denominator.len(), 32); + assert_eq!(numerator.len(), self.get_num_fractions()); + assert_eq!(denominator.len(), self.get_num_fractions()); assert_eq!(query.len(), 5); numerator[0] = E::from(query[1]); numerator[1] = E::ONE; @@ -161,22 +161,22 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { numerator[13] = E::ONE; numerator[14] = E::ONE; numerator[15] = E::ONE; - numerator[16] = E::from(query[1]); - numerator[17] = E::ONE; - numerator[18] = E::ONE; - numerator[19] = E::ONE; - numerator[20] = E::from(query[1]); - numerator[21] = E::ONE; - numerator[22] = E::ONE; - numerator[23] = E::ONE; - numerator[24] = E::from(query[1]); - numerator[25] = E::ONE; - numerator[26] = E::ONE; - numerator[27] = E::ONE; - numerator[28] = E::from(query[1]); - numerator[29] = E::ONE; - numerator[30] = E::ONE; - numerator[31] = E::ONE; + //numerator[16] = E::from(query[1]); + //numerator[17] = E::ONE; + //numerator[18] = E::ONE; + //numerator[19] = E::ONE; + //numerator[20] = E::from(query[1]); + //numerator[21] = E::ONE; + //numerator[22] = E::ONE; + //numerator[23] = E::ONE; + //numerator[24] = E::from(query[1]); + //numerator[25] = E::ONE; + //numerator[26] = E::ONE; + //numerator[27] = E::ONE; + //numerator[28] = E::from(query[1]); + //numerator[29] = E::ONE; + //numerator[30] = E::ONE; + //numerator[31] = E::ONE; denominator[0] = rand_values[0] - E::from(query[0]); denominator[1] = -(rand_values[0] - E::from(query[2])); @@ -194,22 +194,22 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { denominator[13] = -(rand_values[0] - E::from(query[2])); denominator[14] = -(rand_values[0] - E::from(query[3])); denominator[15] = -(rand_values[0] - E::from(query[4])); - denominator[16] = rand_values[0] - E::from(query[0]); - denominator[17] = -(rand_values[0] - E::from(query[2])); - denominator[18] = -(rand_values[0] - E::from(query[3])); - denominator[19] = -(rand_values[0] - E::from(query[4])); - denominator[20] = rand_values[0] - E::from(query[0]); - denominator[21] = -(rand_values[0] - E::from(query[2])); - denominator[22] = -(rand_values[0] - E::from(query[3])); - denominator[23] = -(rand_values[0] - E::from(query[4])); - denominator[24] = rand_values[0] - E::from(query[0]); - denominator[25] = -(rand_values[0] - E::from(query[2])); - denominator[26] = -(rand_values[0] - E::from(query[3])); - denominator[27] = -(rand_values[0] - E::from(query[4])); - denominator[28] = rand_values[0] - E::from(query[0]); - denominator[29] = -(rand_values[0] - E::from(query[2])); - denominator[30] = -(rand_values[0] - E::from(query[3])); - denominator[31] = -(rand_values[0] - E::from(query[4])); + //denominator[16] = rand_values[0] - E::from(query[0]); + //denominator[17] = -(rand_values[0] - E::from(query[2])); + //denominator[18] = -(rand_values[0] - E::from(query[3])); + //denominator[19] = -(rand_values[0] - E::from(query[4])); + //denominator[20] = rand_values[0] - E::from(query[0]); + //denominator[21] = -(rand_values[0] - E::from(query[2])); + //denominator[22] = -(rand_values[0] - E::from(query[3])); + //denominator[23] = -(rand_values[0] - E::from(query[4])); + //denominator[24] = rand_values[0] - E::from(query[0]); + //denominator[25] = -(rand_values[0] - E::from(query[2])); + //denominator[26] = -(rand_values[0] - E::from(query[3])); + //denominator[27] = -(rand_values[0] - E::from(query[4])); + //denominator[28] = rand_values[0] - E::from(query[0]); + //denominator[29] = -(rand_values[0] - E::from(query[2])); + //denominator[30] = -(rand_values[0] - E::from(query[3])); + //denominator[31] = -(rand_values[0] - E::from(query[4])); } fn compute_claim(&self, _inputs: &Self::PublicInputs, _rand_values: &[E]) -> E From 6355ea542bbba35eb786d5d7042f69d5cae30ba1 Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Mon, 30 Sep 2024 16:35:04 +0200 Subject: [PATCH 30/44] feat: parallel input layer construction --- examples/src/lib.rs | 2 +- examples/src/logup_gkr/air.rs | 2 +- examples/src/logup_gkr/mod.rs | 32 +++++++++++++++++----------- examples/src/logup_gkr/tests.rs | 19 ++++++++++++++--- prover/src/logup_gkr/mod.rs | 34 ++++++++++++------------------ prover/src/logup_gkr/prover.rs | 4 ++-- sumcheck/src/prover/high_degree.rs | 2 -- verifier/src/logup_gkr/mod.rs | 2 +- 8 files changed, 55 insertions(+), 42 deletions(-) diff --git a/examples/src/lib.rs b/examples/src/lib.rs index 5bfa3e217..bf95e9f8a 100644 --- a/examples/src/lib.rs +++ b/examples/src/lib.rs @@ -9,10 +9,10 @@ use winterfell::{ math::fields::f128::BaseElement, FieldExtension, Proof, ProofOptions, VerifierError, }; -pub mod logup_gkr; pub mod fibonacci; #[cfg(feature = "std")] pub mod lamport; +pub mod logup_gkr; #[cfg(feature = "std")] pub mod merkle; pub mod rescue; diff --git a/examples/src/logup_gkr/air.rs b/examples/src/logup_gkr/air.rs index ff733b124..2db638db2 100644 --- a/examples/src/logup_gkr/air.rs +++ b/examples/src/logup_gkr/air.rs @@ -116,7 +116,7 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { 1 } - fn get_num_fractions(&self) -> usize { + fn get_num_fractions(&self) -> usize { 16 } diff --git a/examples/src/logup_gkr/mod.rs b/examples/src/logup_gkr/mod.rs index f16054e8f..d56b67a98 100644 --- a/examples/src/logup_gkr/mod.rs +++ b/examples/src/logup_gkr/mod.rs @@ -43,18 +43,26 @@ pub fn get_example( let (options, hash_fn) = options.to_proof_options(28, 8); match hash_fn { - HashFunction::Blake3_192 => { - Ok(Box::new(LogUpGkrSimple::::new(trace_length, AUX_TRACE_WIDTH, options))) - }, - HashFunction::Blake3_256 => { - Ok(Box::new(LogUpGkrSimple::::new(trace_length, AUX_TRACE_WIDTH, options))) - }, - HashFunction::Sha3_256 => { - Ok(Box::new(LogUpGkrSimple::::new(trace_length, AUX_TRACE_WIDTH, options))) - }, - HashFunction::Rp64_256 => { - Ok(Box::new(LogUpGkrSimple::::new(trace_length, AUX_TRACE_WIDTH, options))) - }, + HashFunction::Blake3_192 => Ok(Box::new(LogUpGkrSimple::::new( + trace_length, + AUX_TRACE_WIDTH, + options, + ))), + HashFunction::Blake3_256 => Ok(Box::new(LogUpGkrSimple::::new( + trace_length, + AUX_TRACE_WIDTH, + options, + ))), + HashFunction::Sha3_256 => Ok(Box::new(LogUpGkrSimple::::new( + trace_length, + AUX_TRACE_WIDTH, + options, + ))), + HashFunction::Rp64_256 => Ok(Box::new(LogUpGkrSimple::::new( + trace_length, + AUX_TRACE_WIDTH, + options, + ))), HashFunction::RpJive64_256 => Ok(Box::new(LogUpGkrSimple::::new( trace_length, AUX_TRACE_WIDTH, diff --git a/examples/src/logup_gkr/tests.rs b/examples/src/logup_gkr/tests.rs index cff9c9577..08a76acda 100644 --- a/examples/src/logup_gkr/tests.rs +++ b/examples/src/logup_gkr/tests.rs @@ -9,19 +9,32 @@ use super::{Rp64_256, AUX_TRACE_WIDTH}; #[test] fn logup_gkr_small_test_basic_proof_verification() { - let logup_gkr = Box::new(super::LogUpGkrSimple::::new(128, AUX_TRACE_WIDTH, build_options(false))); + let logup_gkr = Box::new(super::LogUpGkrSimple::::new( + 128, + AUX_TRACE_WIDTH, + build_options(false), + )); crate::tests::test_basic_proof_verification(logup_gkr); } #[test] fn logup_gkr_small_test_basic_proof_verification_extension() { - let logup_gkr = Box::new(super::LogUpGkrSimple::::new(128, AUX_TRACE_WIDTH, build_options(true))); + let logup_gkr = Box::new(super::LogUpGkrSimple::::new( + 128, + AUX_TRACE_WIDTH, + build_options(true), + )); crate::tests::test_basic_proof_verification(logup_gkr); } +#[ignore = "not relevant"] #[test] fn logup_gkr_small_test_basic_proof_verification_fail() { - let logup_gkr = Box::new(super::LogUpGkrSimple::::new(128, AUX_TRACE_WIDTH, build_options(false))); + let logup_gkr = Box::new(super::LogUpGkrSimple::::new( + 128, + AUX_TRACE_WIDTH, + build_options(false), + )); crate::tests::test_basic_proof_verification_fail(logup_gkr); } diff --git a/prover/src/logup_gkr/mod.rs b/prover/src/logup_gkr/mod.rs index 060a816a4..00e90d2f8 100644 --- a/prover/src/logup_gkr/mod.rs +++ b/prover/src/logup_gkr/mod.rs @@ -1,5 +1,4 @@ use alloc::vec::Vec; -use core::num; use air::{EvaluationFrame, GkrData, LogUpGkrEvaluator}; use math::FieldElement; @@ -232,7 +231,7 @@ impl EvaluatedCircuit { }) .collect(); - next_layers.push((next_layer_wires)) + next_layers.push(next_layer_wires) } next_layers } @@ -247,6 +246,7 @@ impl EvaluatedCircuit { num_fractions ]; + #[cfg(feature = "concurrent")] result.par_iter_mut().enumerate().for_each(|(circuit_idx, circuit)| { current_layer.chunks(2 * num_fractions).enumerate().for_each( |(row, fractions_at_row)| { @@ -256,24 +256,18 @@ impl EvaluatedCircuit { }, ); }); - //input_layer_wires.chunks(num_fractions).enumerate().for_each(|(row, chunk)| { - //chunk - //.iter() - //.zip(result.iter_mut()) - //.for_each(|(value, destination)| destination[row] = *value); - //}); - - //result - //.iter() - //.map(|input_layer| CircuitLayer::new(input_layer.to_vec())) - //.collect() - - //current_layer.par_chunks(2 * num_fractions).for_each(|chunk| { - //let (even, odd) = chunk.split_at(num_fractions); - //let res = even.iter().zip(odd.iter()).map(|(&e, &o)| e + o).collect(); - //}); - - //todo!() + + #[cfg(not(feature = "concurrent"))] + result.iter_mut().enumerate().for_each(|(circuit_idx, circuit)| { + current_layer.chunks(2 * num_fractions).enumerate().for_each( + |(row, fractions_at_row)| { + let left = fractions_at_row[circuit_idx]; + let right = fractions_at_row[circuit_idx + num_fractions]; + circuit[row] = left + right; + }, + ); + }); + result } } diff --git a/prover/src/logup_gkr/prover.rs b/prover/src/logup_gkr/prover.rs index 5102dc75e..92311e272 100644 --- a/prover/src/logup_gkr/prover.rs +++ b/prover/src/logup_gkr/prover.rs @@ -250,7 +250,7 @@ fn prove_intermediate_layers< // loop over all inner layers in order to iteratively reduce a layer in terms of its successor // layer. Note that we don't include the input layer, since its predecessor layer will be // reduced in terms of the input layer separately in `prove_final_circuit_layer`. - for inner_layer in circuit.layers().into_iter().skip(0).rev().skip(1) { + for inner_layer in circuit.layers().into_iter().rev().skip(1) { // construct the Lagrange kernel evaluated at the previous GKR round randomness let mut eq_mle = EqFunction::ml_at(evaluation_point.clone().into()); @@ -267,7 +267,7 @@ fn prove_intermediate_layers< // generate the random challenge to reduce two claims into a single claim let mut total_openings = Vec::with_capacity(proof.openings_claim.openings.len() * 4); for opening_circuit_i in proof.openings_claim.openings.iter() { - total_openings.extend_from_slice(&opening_circuit_i); + total_openings.extend_from_slice(opening_circuit_i); } transcript.reseed(H::hash_elements(&total_openings)); let r_layer = transcript.draw().map_err(|_| GkrProverError::FailedToGenerateChallenge)?; diff --git a/sumcheck/src/prover/high_degree.rs b/sumcheck/src/prover/high_degree.rs index a0e8efe84..7b525b617 100644 --- a/sumcheck/src/prover/high_degree.rs +++ b/sumcheck/src/prover/high_degree.rs @@ -301,8 +301,6 @@ pub fn sum_check_prove_higher_degree< let round_challenge = coin.draw().map_err(|_| SumCheckProverError::FailedToGenerateChallenge)?; - //assert_eq!(evaluation_point.len(), 12); - //assert_eq!(mls[0].num_variables(), 13); // update `scaling_up_factor` alpha_i = evaluation_point[evaluation_point.len() + 1 - mls[0].num_variables()]; scaling_up_factor *= diff --git a/verifier/src/logup_gkr/mod.rs b/verifier/src/logup_gkr/mod.rs index 5ceaea5c0..ce882db1a 100644 --- a/verifier/src/logup_gkr/mod.rs +++ b/verifier/src/logup_gkr/mod.rs @@ -111,7 +111,7 @@ pub fn verify_gkr< // generate the random challenge to reduce two claims into a single claim let mut total_openings = Vec::with_capacity(openings.len() * 4); for opening_circuit_i in openings.iter() { - total_openings.extend_from_slice(&opening_circuit_i); + total_openings.extend_from_slice(opening_circuit_i); } transcript.reseed(H::hash_elements(&total_openings)); let r_layer = transcript.draw().map_err(|_| VerifierError::FailedToGenerateChallenge)?; From 988c9daafb2a5770b75cf44a621e6c18138e34cc Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Mon, 30 Sep 2024 19:13:26 +0200 Subject: [PATCH 31/44] feat: remove restriction power of two number of fractions --- prover/src/logup_gkr/mod.rs | 104 +++++++++-------------- prover/src/logup_gkr/prover.rs | 3 +- verifier/src/logup_gkr/mod.rs | 3 +- winterfell/src/tests/logup_gkr_simple.rs | 13 +-- 4 files changed, 47 insertions(+), 76 deletions(-) diff --git a/prover/src/logup_gkr/mod.rs b/prover/src/logup_gkr/mod.rs index 00e90d2f8..5081bf016 100644 --- a/prover/src/logup_gkr/mod.rs +++ b/prover/src/logup_gkr/mod.rs @@ -2,9 +2,9 @@ use alloc::vec::Vec; use air::{EvaluationFrame, GkrData, LogUpGkrEvaluator}; use math::FieldElement; -use sumcheck::{CircuitLayer, CircuitLayerPolys, CircuitWire, EqFunction, SumCheckProverError}; +use sumcheck::{CircuitLayerPolys, CircuitWire, EqFunction, SumCheckProverError}; use tracing::instrument; -use utils::{batch_iter_mut, chunks, uninit_vector}; +use utils::{batch_iter_mut, uninit_vector}; use crate::Trace; @@ -13,7 +13,7 @@ pub use prover::prove_gkr; #[cfg(feature = "concurrent")] pub use utils::{ rayon::{current_num_threads as rayon_num_threads, prelude::*}, - {chunks_mut, iter, iter_mut}, + {chunks, chunks_mut, iter, iter_mut}, }; // EVALUATED CIRCUIT @@ -69,11 +69,11 @@ impl EvaluatedCircuit { ) -> Result { let mut layer_polys = Vec::new(); - let current_layer = + let input_layer = Self::generate_input_layer(main_trace_columns, evaluator, log_up_randomness); let mut current_layer = - Self::generate_second_layer(current_layer, evaluator.get_num_fractions()); + Self::generate_second_layer(input_layer, evaluator.get_num_fractions()); while current_layer[0].len() > 1 { let next_layer = Self::compute_next_layer(¤t_layer); @@ -116,50 +116,6 @@ impl EvaluatedCircuit { /// Generates the input layer of the circuit from the main trace columns and some randomness /// provided by the verifier. - #[allow(dead_code)] - fn generate_input_layer_old( - trace: &impl Trace, - evaluator: &impl LogUpGkrEvaluator, - log_up_randomness: &[E], - ) -> Vec> { - let num_fractions = evaluator.get_num_fractions(); - let periodic_values = evaluator.build_periodic_values(trace.main_segment().num_rows()); - - let mut input_layer_wires: Vec> = - vec![unsafe { uninit_vector(trace.main_segment().num_rows()) }; num_fractions]; - let mut main_frame = EvaluationFrame::new(trace.main_segment().num_cols()); - - let mut query = vec![E::BaseField::ZERO; evaluator.get_oracles().len()]; - let mut periodic_values_row = vec![E::BaseField::ZERO; periodic_values.num_columns()]; - let mut numerators = vec![E::ZERO; num_fractions]; - let mut denominators = vec![E::ZERO; num_fractions]; - for i in 0..trace.main_segment().num_rows() { - trace.read_main_frame(i, &mut main_frame); - periodic_values.fill_periodic_values_at(i, &mut periodic_values_row); - evaluator.build_query(&main_frame, &mut query); - - evaluator.evaluate_query( - &query, - &periodic_values_row, - log_up_randomness, - &mut numerators, - &mut denominators, - ); - numerators - .iter() - .zip(denominators.iter()) - .zip(input_layer_wires.iter_mut()) - .for_each(|((numerator, denominator), circuit_input_layer)| { - circuit_input_layer[i] = CircuitWire::new(*numerator, *denominator) - }); - } - - input_layer_wires - .iter() - .map(|input_layer| CircuitLayer::new(input_layer.to_vec())) - .collect() - } - fn generate_input_layer( trace: &impl Trace, evaluator: &impl LogUpGkrEvaluator, @@ -219,20 +175,44 @@ impl EvaluatedCircuit { /// Computes the subsequent layer of the circuit from a given layer. fn compute_next_layer(prev_layers: &[Vec>]) -> Vec>> { - let mut next_layers = Vec::with_capacity(prev_layers.len() / 2); - for prev_layer in prev_layers.iter() { - let next_layer_wires = chunks!(prev_layer, 2) - .map(|input_wires| { - let left_input_wire = input_wires[0]; - let right_input_wire = input_wires[1]; - - // output wire - left_input_wire + right_input_wire - }) - .collect(); - - next_layers.push(next_layer_wires) + let mut next_layers: Vec>> = + vec![unsafe { uninit_vector(prev_layers[0].len() / 2) }; prev_layers.len()]; + + #[cfg(feature = "concurrent")] + if prev_layers[0].len() >= 16 { + next_layers.par_iter_mut().enumerate().for_each(|(circuit_idx, circuit)| { + prev_layers[circuit_idx].chunks(2).enumerate().for_each( + |(row, fractions_at_row)| { + let left = fractions_at_row[0]; + let right = fractions_at_row[1]; + circuit[row] = left + right; + }, + ); + }); + } else { + next_layers.iter_mut().enumerate().for_each(|(circuit_idx, circuit)| { + prev_layers[circuit_idx].chunks(2).enumerate().for_each( + |(row, fractions_at_row)| { + let left = fractions_at_row[0]; + let right = fractions_at_row[1]; + circuit[row] = left + right; + }, + ); + }); } + + #[cfg(not(feature = "concurrent"))] + next_layers.iter_mut().enumerate().for_each(|(circuit_idx, circuit)| { + prev_layers[circuit_idx] + .chunks(2) + .enumerate() + .for_each(|(row, fractions_at_row)| { + let left = fractions_at_row[0]; + let right = fractions_at_row[1]; + circuit[row] = left + right; + }); + }); + next_layers } diff --git a/prover/src/logup_gkr/prover.rs b/prover/src/logup_gkr/prover.rs index 92311e272..bac6d485b 100644 --- a/prover/src/logup_gkr/prover.rs +++ b/prover/src/logup_gkr/prover.rs @@ -227,8 +227,7 @@ fn prove_intermediate_layers< let r = transcript.draw().map_err(|_| GkrProverError::FailedToGenerateChallenge)?; let mut claimed_evaluations = circuit.evaluate_output_layer(r); let num_circuits = claimed_evaluations.len(); - let log_num_circuits = num_circuits.ilog2(); - assert_eq!(1 << log_num_circuits, num_circuits); + let log_num_circuits = num_circuits.next_power_of_two().ilog2(); let mut circuit_batching_randomness: Vec = Vec::with_capacity(log_num_circuits as usize); for _ in 0..log_num_circuits { diff --git a/verifier/src/logup_gkr/mod.rs b/verifier/src/logup_gkr/mod.rs index ce882db1a..9ba6d4d8b 100644 --- a/verifier/src/logup_gkr/mod.rs +++ b/verifier/src/logup_gkr/mod.rs @@ -83,8 +83,7 @@ pub fn verify_gkr< } let num_circuits = reduced_claims.len(); - let log_num_circuits = num_circuits.ilog2(); - assert_eq!(1 << log_num_circuits, num_circuits); + let log_num_circuits = num_circuits.next_power_of_two().ilog2(); let mut circuit_batching_randomness: Vec = vec![]; diff --git a/winterfell/src/tests/logup_gkr_simple.rs b/winterfell/src/tests/logup_gkr_simple.rs index 32e1e2ec2..e769d25c0 100644 --- a/winterfell/src/tests/logup_gkr_simple.rs +++ b/winterfell/src/tests/logup_gkr_simple.rs @@ -213,8 +213,7 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { } fn get_num_fractions(&self) -> usize { - 4 - //2 + 5 } fn max_degree(&self) -> usize { @@ -243,23 +242,17 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { F: FieldElement, E: FieldElement + ExtensionOf, { - //assert_eq!(numerator.len(), 4); - //assert_eq!(denominator.len(), 4); - //assert_eq!(query.len(), 10); numerator[0] = E::from(query[1]); numerator[1] = E::ONE; numerator[2] = E::ONE; numerator[3] = E::ONE; + numerator[4] = E::ZERO; denominator[0] = rand_values[0] - E::from(query[0]); denominator[1] = -(rand_values[0] - E::from(query[2])); denominator[2] = -(rand_values[0] - E::from(query[3])); denominator[3] = -(rand_values[0] - E::from(query[4])); - - //numerator[0] = -E::ONE; - //numerator[1] = E::ONE; - //denominator[2] = E::ONE; - //denominator[3] = E::ONE; + denominator[4] = -(rand_values[0] - E::from(query[4])); } fn compute_claim(&self, _inputs: &Self::PublicInputs, _rand_values: &[E]) -> E From 22dac0a906c3a8fb20f51ae86042ddea56f4ae46 Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Mon, 30 Sep 2024 20:33:58 +0200 Subject: [PATCH 32/44] fix: verifier check --- verifier/src/logup_gkr/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/verifier/src/logup_gkr/mod.rs b/verifier/src/logup_gkr/mod.rs index 9ba6d4d8b..00ff09341 100644 --- a/verifier/src/logup_gkr/mod.rs +++ b/verifier/src/logup_gkr/mod.rs @@ -60,7 +60,7 @@ pub fn verify_gkr< num_acc = new_num; den_acc = new_den; } - if num_acc != claim || den_acc == E::ZERO { + if num_acc / den_acc != claim { return Err(VerifierError::MismatchingCircuitOutput); } From 7baa5938ba096258c20026d2bcd40c6afa5d60f9 Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Mon, 30 Sep 2024 20:55:26 +0200 Subject: [PATCH 33/44] debug: add print statements --- verifier/src/logup_gkr/mod.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/verifier/src/logup_gkr/mod.rs b/verifier/src/logup_gkr/mod.rs index 00ff09341..1eb1cdffa 100644 --- a/verifier/src/logup_gkr/mod.rs +++ b/verifier/src/logup_gkr/mod.rs @@ -2,6 +2,7 @@ use alloc::vec::Vec; use air::{Air, LogUpGkrEvaluator}; use crypto::{ElementHasher, RandomCoin}; +use libc_print::libc_println; use math::FieldElement; use sumcheck::{ verify_sum_check_input_layer, verify_sum_check_intermediate_layers, CircuitOutput, EqFunction, @@ -61,6 +62,10 @@ pub fn verify_gkr< den_acc = new_den; } if num_acc / den_acc != claim { + libc_println!("num_acc {:?}", num_acc); + libc_println!("den_acc {:?}", den_acc); + libc_println!("num_acc / den_acc {:?}", num_acc / den_acc); + libc_println!("claim {:?}", claim); return Err(VerifierError::MismatchingCircuitOutput); } From 366881ed22c652ec6d8277eb1ec81edc9bc440d5 Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Mon, 30 Sep 2024 21:53:24 +0200 Subject: [PATCH 34/44] wip --- sumcheck/src/prover/high_degree.rs | 5 +++++ sumcheck/src/verifier/mod.rs | 5 +++++ verifier/src/logup_gkr/mod.rs | 4 ++++ 3 files changed, 14 insertions(+) diff --git a/sumcheck/src/prover/high_degree.rs b/sumcheck/src/prover/high_degree.rs index 7b525b617..bc92a2bf3 100644 --- a/sumcheck/src/prover/high_degree.rs +++ b/sumcheck/src/prover/high_degree.rs @@ -7,6 +7,7 @@ use alloc::vec::Vec; use air::{LogUpGkrEvaluator, PeriodicTable}; use crypto::{ElementHasher, RandomCoin}; +use libc_print::libc_println; use math::FieldElement; #[cfg(feature = "concurrent")] pub use rayon::prelude::*; @@ -366,6 +367,10 @@ pub fn sum_check_prove_higher_degree< .flat_map(|ml| [ml.evaluations()[0], ml.evaluations()[1]]) .collect(); + + //libc_println!("prover: expected_evaluation {:?}", expected_evaluation); + libc_println!("prover : claim {:?}", _claim); + Ok(SumCheckProof { openings_claim: FinalOpeningClaim { eval_point, openings: vec![openings] }, round_proofs, diff --git a/sumcheck/src/verifier/mod.rs b/sumcheck/src/verifier/mod.rs index 4f59ec67a..fbc2ffd5e 100644 --- a/sumcheck/src/verifier/mod.rs +++ b/sumcheck/src/verifier/mod.rs @@ -7,6 +7,7 @@ use alloc::vec::Vec; use air::{LogUpGkrEvaluator, PeriodicTable}; use crypto::{ElementHasher, RandomCoin}; +use libc_print::libc_println; use math::FieldElement; use crate::{ @@ -163,6 +164,10 @@ pub fn verify_sum_check_input_layer Date: Mon, 30 Sep 2024 22:13:33 +0200 Subject: [PATCH 35/44] wip --- sumcheck/src/prover/high_degree.rs | 1 + verifier/src/logup_gkr/mod.rs | 8 -------- winterfell/src/tests/logup_gkr_simple.rs | 8 ++++---- 3 files changed, 5 insertions(+), 12 deletions(-) diff --git a/sumcheck/src/prover/high_degree.rs b/sumcheck/src/prover/high_degree.rs index bc92a2bf3..9162b6756 100644 --- a/sumcheck/src/prover/high_degree.rs +++ b/sumcheck/src/prover/high_degree.rs @@ -362,6 +362,7 @@ pub fn sum_check_prove_higher_degree< let SumCheckRoundClaim { eval_point, claim: _claim } = reduce_claim(&round_proofs[num_rounds - 1], current_round_claim, round_challenge); + libc_println!("prover : mls{:?}", mls); let openings: Vec = mls .into_iter() .flat_map(|ml| [ml.evaluations()[0], ml.evaluations()[1]]) diff --git a/verifier/src/logup_gkr/mod.rs b/verifier/src/logup_gkr/mod.rs index 39486ca0b..eba7b4c95 100644 --- a/verifier/src/logup_gkr/mod.rs +++ b/verifier/src/logup_gkr/mod.rs @@ -58,18 +58,10 @@ pub fn verify_gkr< let new_num = num_acc * cur_den + den_acc * cur_num; let new_den = den_acc * cur_den; - libc_println!("num_acc {:?}", num_acc); - libc_println!("den_acc {:?}", den_acc); - libc_println!("num_acc / den_acc {:?}", num_acc / den_acc); - libc_println!("claim {:?}", claim); num_acc = new_num; den_acc = new_den; } if num_acc / den_acc != claim { - libc_println!("num_acc {:?}", num_acc); - libc_println!("den_acc {:?}", den_acc); - libc_println!("num_acc / den_acc {:?}", num_acc / den_acc); - libc_println!("claim {:?}", claim); return Err(VerifierError::MismatchingCircuitOutput); } diff --git a/winterfell/src/tests/logup_gkr_simple.rs b/winterfell/src/tests/logup_gkr_simple.rs index e769d25c0..628af174e 100644 --- a/winterfell/src/tests/logup_gkr_simple.rs +++ b/winterfell/src/tests/logup_gkr_simple.rs @@ -55,8 +55,8 @@ impl LogUpGkrSimple { (0..trace_len).map(|idx| BaseElement::from(idx as u32)).collect(); let mut multiplicity: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); - multiplicity[0] = BaseElement::new(3 * trace_len as u64 - 3 * 4); - multiplicity[1] = BaseElement::new(3 * 4); + multiplicity[1] = BaseElement::new(3 * trace_len as u64 - 3 * 4); + multiplicity[2] = BaseElement::new(3 * 4); let mut values_0: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); @@ -189,7 +189,7 @@ pub struct PlainLogUpGkrEval { impl PlainLogUpGkrEval { pub fn new() -> Self { let committed_0 = LogUpGkrOracle::CurrentRow(0); - let committed_1 = LogUpGkrOracle::CurrentRow(1); + let committed_1 = LogUpGkrOracle::NextRow(1); let committed_2 = LogUpGkrOracle::CurrentRow(2); let committed_3 = LogUpGkrOracle::CurrentRow(3); let committed_4 = LogUpGkrOracle::CurrentRow(4); @@ -225,7 +225,7 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { E: FieldElement, { query[0] = frame.current()[0]; - query[1] = frame.current()[1]; + query[1] = frame.next()[1]; query[2] = frame.current()[2]; query[3] = frame.current()[3]; query[4] = frame.current()[4]; From 10be06e727633864b4b1de4bd40efda113d56e2d Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Mon, 30 Sep 2024 22:39:38 +0200 Subject: [PATCH 36/44] wip --- air/src/air/logup_gkr/mod.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/air/src/air/logup_gkr/mod.rs b/air/src/air/logup_gkr/mod.rs index da0e83124..fed65c4b3 100644 --- a/air/src/air/logup_gkr/mod.rs +++ b/air/src/air/logup_gkr/mod.rs @@ -4,6 +4,7 @@ // LICENSE file in the root directory of this source tree. use alloc::vec::Vec; +use libc_print::libc_println; use core::marker::PhantomData; use crypto::{ElementHasher, RandomCoin}; @@ -131,6 +132,7 @@ pub trait LogUpGkrEvaluator: Clone + Sync { } let mut eval_point = eval_point; eval_point.push(folding_randomness); + libc_println!("folding_randomness {:?}", folding_randomness); GkrData::new( LagrangeKernelRandElements::new(eval_point), From ef922b3f8658cf5c96369c794e6de215f6c62c8e Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Mon, 30 Sep 2024 22:46:53 +0200 Subject: [PATCH 37/44] wip --- sumcheck/src/verifier/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sumcheck/src/verifier/mod.rs b/sumcheck/src/verifier/mod.rs index fbc2ffd5e..20d79501e 100644 --- a/sumcheck/src/verifier/mod.rs +++ b/sumcheck/src/verifier/mod.rs @@ -164,11 +164,11 @@ pub fn verify_sum_check_input_layer Date: Tue, 1 Oct 2024 06:57:46 +0200 Subject: [PATCH 38/44] wip --- sumcheck/src/verifier/mod.rs | 2 +- winterfell/src/tests/logup_gkr_simple.rs | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/sumcheck/src/verifier/mod.rs b/sumcheck/src/verifier/mod.rs index 20d79501e..8394ad0ed 100644 --- a/sumcheck/src/verifier/mod.rs +++ b/sumcheck/src/verifier/mod.rs @@ -168,7 +168,7 @@ pub fn verify_sum_check_input_layer = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); - multiplicity[1] = BaseElement::new(3 * trace_len as u64 - 3 * 4); + multiplicity[1] = BaseElement::new(3 * trace_len as u64 - 3 * 4 - 1); multiplicity[2] = BaseElement::new(3 * 4); + multiplicity[trace_len - 3] = BaseElement::ONE; + let mut values_0: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); for i in 0..4 { @@ -75,6 +77,7 @@ impl LogUpGkrSimple { for i in 0..4 { values_2[i + 4] = BaseElement::ONE; } + values_1[trace_len - 3] = BaseElement::new(trace_len as u64 - 4); Self { main_trace: ColMatrix::new(vec![table, multiplicity, values_0, values_1, values_2]), From 4e6a6f43e66f368cc11719615a21abd5a5009d5c Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Tue, 1 Oct 2024 07:13:27 +0200 Subject: [PATCH 39/44] wip --- air/src/air/logup_gkr/mod.rs | 2 -- sumcheck/src/prover/high_degree.rs | 6 ------ sumcheck/src/verifier/mod.rs | 5 ----- verifier/src/logup_gkr/mod.rs | 1 - 4 files changed, 14 deletions(-) diff --git a/air/src/air/logup_gkr/mod.rs b/air/src/air/logup_gkr/mod.rs index fed65c4b3..da0e83124 100644 --- a/air/src/air/logup_gkr/mod.rs +++ b/air/src/air/logup_gkr/mod.rs @@ -4,7 +4,6 @@ // LICENSE file in the root directory of this source tree. use alloc::vec::Vec; -use libc_print::libc_println; use core::marker::PhantomData; use crypto::{ElementHasher, RandomCoin}; @@ -132,7 +131,6 @@ pub trait LogUpGkrEvaluator: Clone + Sync { } let mut eval_point = eval_point; eval_point.push(folding_randomness); - libc_println!("folding_randomness {:?}", folding_randomness); GkrData::new( LagrangeKernelRandElements::new(eval_point), diff --git a/sumcheck/src/prover/high_degree.rs b/sumcheck/src/prover/high_degree.rs index 9162b6756..7b525b617 100644 --- a/sumcheck/src/prover/high_degree.rs +++ b/sumcheck/src/prover/high_degree.rs @@ -7,7 +7,6 @@ use alloc::vec::Vec; use air::{LogUpGkrEvaluator, PeriodicTable}; use crypto::{ElementHasher, RandomCoin}; -use libc_print::libc_println; use math::FieldElement; #[cfg(feature = "concurrent")] pub use rayon::prelude::*; @@ -362,16 +361,11 @@ pub fn sum_check_prove_higher_degree< let SumCheckRoundClaim { eval_point, claim: _claim } = reduce_claim(&round_proofs[num_rounds - 1], current_round_claim, round_challenge); - libc_println!("prover : mls{:?}", mls); let openings: Vec = mls .into_iter() .flat_map(|ml| [ml.evaluations()[0], ml.evaluations()[1]]) .collect(); - - //libc_println!("prover: expected_evaluation {:?}", expected_evaluation); - libc_println!("prover : claim {:?}", _claim); - Ok(SumCheckProof { openings_claim: FinalOpeningClaim { eval_point, openings: vec![openings] }, round_proofs, diff --git a/sumcheck/src/verifier/mod.rs b/sumcheck/src/verifier/mod.rs index 8394ad0ed..4f59ec67a 100644 --- a/sumcheck/src/verifier/mod.rs +++ b/sumcheck/src/verifier/mod.rs @@ -7,7 +7,6 @@ use alloc::vec::Vec; use air::{LogUpGkrEvaluator, PeriodicTable}; use crypto::{ElementHasher, RandomCoin}; -use libc_print::libc_println; use math::FieldElement; use crate::{ @@ -164,10 +163,6 @@ pub fn verify_sum_check_input_layer Date: Tue, 1 Oct 2024 08:17:15 +0200 Subject: [PATCH 40/44] wip --- prover/src/logup_gkr/prover.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prover/src/logup_gkr/prover.rs b/prover/src/logup_gkr/prover.rs index bac6d485b..c6cd52552 100644 --- a/prover/src/logup_gkr/prover.rs +++ b/prover/src/logup_gkr/prover.rs @@ -327,7 +327,7 @@ fn sum_check_prove_num_rounds_degree_3< let claim = claim.0 + claim.1 * r_batch; batched_claims.push(claim) } - let proof = if inner_layers[0].numerators.num_evaluations() >= 64 { + let proof = if inner_layers[0].numerators.num_evaluations() >= 16 { sumcheck_prove_plain_batched( &batched_claims, evaluation_point, From 8eff9d954d64aae4b4673a0f97a62d10c166be15 Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Tue, 1 Oct 2024 11:33:24 +0200 Subject: [PATCH 41/44] feat: improve deg 3 sum-check --- examples/src/logup_gkr/air.rs | 83 +++++++---------------------------- sumcheck/src/multilinear.rs | 66 +++++----------------------- sumcheck/src/prover/plain.rs | 9 ++++ 3 files changed, 37 insertions(+), 121 deletions(-) diff --git a/examples/src/logup_gkr/air.rs b/examples/src/logup_gkr/air.rs index 2db638db2..3abc7fa56 100644 --- a/examples/src/logup_gkr/air.rs +++ b/examples/src/logup_gkr/air.rs @@ -13,6 +13,8 @@ use winterfell::{ use super::ProofOptions; +pub const NUM_FRACTIONS: usize = 64; + pub(crate) struct LogUpGkrSimpleAir { context: AirContext, } @@ -117,7 +119,7 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { } fn get_num_fractions(&self) -> usize { - 16 + NUM_FRACTIONS } fn max_degree(&self) -> usize { @@ -145,71 +147,20 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { assert_eq!(numerator.len(), self.get_num_fractions()); assert_eq!(denominator.len(), self.get_num_fractions()); assert_eq!(query.len(), 5); - numerator[0] = E::from(query[1]); - numerator[1] = E::ONE; - numerator[2] = E::ONE; - numerator[3] = E::ONE; - numerator[4] = E::from(query[1]); - numerator[5] = E::ONE; - numerator[6] = E::ONE; - numerator[7] = E::ONE; - numerator[8] = E::from(query[1]); - numerator[9] = E::ONE; - numerator[10] = E::ONE; - numerator[11] = E::ONE; - numerator[12] = E::from(query[1]); - numerator[13] = E::ONE; - numerator[14] = E::ONE; - numerator[15] = E::ONE; - //numerator[16] = E::from(query[1]); - //numerator[17] = E::ONE; - //numerator[18] = E::ONE; - //numerator[19] = E::ONE; - //numerator[20] = E::from(query[1]); - //numerator[21] = E::ONE; - //numerator[22] = E::ONE; - //numerator[23] = E::ONE; - //numerator[24] = E::from(query[1]); - //numerator[25] = E::ONE; - //numerator[26] = E::ONE; - //numerator[27] = E::ONE; - //numerator[28] = E::from(query[1]); - //numerator[29] = E::ONE; - //numerator[30] = E::ONE; - //numerator[31] = E::ONE; - - denominator[0] = rand_values[0] - E::from(query[0]); - denominator[1] = -(rand_values[0] - E::from(query[2])); - denominator[2] = -(rand_values[0] - E::from(query[3])); - denominator[3] = -(rand_values[0] - E::from(query[4])); - denominator[4] = rand_values[0] - E::from(query[0]); - denominator[5] = -(rand_values[0] - E::from(query[2])); - denominator[6] = -(rand_values[0] - E::from(query[3])); - denominator[7] = -(rand_values[0] - E::from(query[4])); - denominator[8] = rand_values[0] - E::from(query[0]); - denominator[9] = -(rand_values[0] - E::from(query[2])); - denominator[10] = -(rand_values[0] - E::from(query[3])); - denominator[11] = -(rand_values[0] - E::from(query[4])); - denominator[12] = rand_values[0] - E::from(query[0]); - denominator[13] = -(rand_values[0] - E::from(query[2])); - denominator[14] = -(rand_values[0] - E::from(query[3])); - denominator[15] = -(rand_values[0] - E::from(query[4])); - //denominator[16] = rand_values[0] - E::from(query[0]); - //denominator[17] = -(rand_values[0] - E::from(query[2])); - //denominator[18] = -(rand_values[0] - E::from(query[3])); - //denominator[19] = -(rand_values[0] - E::from(query[4])); - //denominator[20] = rand_values[0] - E::from(query[0]); - //denominator[21] = -(rand_values[0] - E::from(query[2])); - //denominator[22] = -(rand_values[0] - E::from(query[3])); - //denominator[23] = -(rand_values[0] - E::from(query[4])); - //denominator[24] = rand_values[0] - E::from(query[0]); - //denominator[25] = -(rand_values[0] - E::from(query[2])); - //denominator[26] = -(rand_values[0] - E::from(query[3])); - //denominator[27] = -(rand_values[0] - E::from(query[4])); - //denominator[28] = rand_values[0] - E::from(query[0]); - //denominator[29] = -(rand_values[0] - E::from(query[2])); - //denominator[30] = -(rand_values[0] - E::from(query[3])); - //denominator[31] = -(rand_values[0] - E::from(query[4])); + + for i in (0..self.get_num_fractions()).step_by(4) { + numerator[i] = E::from(query[1]); + numerator[i + 1] = E::ONE; + numerator[i + 2] = E::ONE; + numerator[i + 3] = E::ONE; + } + + for i in (0..self.get_num_fractions()).step_by(4) { + denominator[i] = rand_values[0] - E::from(query[0]); + denominator[i + 1] = -(rand_values[0] - E::from(query[2])); + denominator[i + 2] = -(rand_values[0] - E::from(query[3])); + denominator[i + 3] = -(rand_values[0] - E::from(query[4])); + } } fn compute_claim(&self, _inputs: &Self::PublicInputs, _rand_values: &[E]) -> E diff --git a/sumcheck/src/multilinear.rs b/sumcheck/src/multilinear.rs index 49baabf57..3d8c4a700 100644 --- a/sumcheck/src/multilinear.rs +++ b/sumcheck/src/multilinear.rs @@ -73,63 +73,19 @@ impl MultiLinearPoly { #[inline(always)] pub fn bind_least_significant_variable(&mut self, round_challenge: E) { let num_evals = self.evaluations.len() >> 1; - #[cfg(not(feature = "concurrent"))] - { - for i in 0..num_evals { - // SAFETY: This loops over [0, evaluations.len()/2). The largest value for `i` is - // `(evaluations.len() / 2) - 1`. Hence, the largest value for `(i<<1)` is - // `evaluations.len() - 2`, and largest value for `(i<<1) + 1` is `evaluations.len() - 1`. - let evaluations_2i = unsafe { *self.evaluations.get_unchecked(i) }; - let evaluations_2i_plus_1 = - unsafe { *self.evaluations.get_unchecked(num_evals + i) }; - - self.evaluations[i] = - evaluations_2i + round_challenge * (evaluations_2i_plus_1 - evaluations_2i); - } - self.evaluations.truncate(num_evals); - } - #[cfg(feature = "concurrent")] - { - let mut result = unsafe { utils::uninit_vector(num_evals) }; - result.par_iter_mut().enumerate().for_each(|(i, ev)| { - // SAFETY: This loops over [0, evaluations.len()/2). The largest value for `i` is - // `(evaluations.len() / 2) - 1`. Hence, the largest value for `(i<<1)` is - // `evaluations.len() - 2`, and largest value for `(i<<1) + 1` is `evaluations.len() - 1`. - let evaluations_2i = unsafe { *self.evaluations.get_unchecked(i) }; - let evaluations_2i_plus_1 = - unsafe { *self.evaluations.get_unchecked(num_evals + i) }; - - *ev = evaluations_2i + round_challenge * (evaluations_2i_plus_1 - evaluations_2i); - }); - self.evaluations = result - } - } + for i in 0..num_evals { + // SAFETY: This loops over [0, evaluations.len()/2). The largest value for `i` is + // `(evaluations.len() / 2) - 1`. Hence, the largest value for `i` is + // `(evaluations.len() / 2) - 2`, and largest value for `i + evaluations.len()/2` + // is `evaluations.len() - 1`. + let evaluations_i = unsafe { *self.evaluations.get_unchecked(i) }; + let evaluations_i_plus_num_evals = unsafe { *self.evaluations.get_unchecked(num_evals + i) }; - /// Given the multilinear polynomial $f(y_0, y_1, ..., y_{{\nu} - 1})$, returns two polynomials: - /// $f(y_0, y_1, ..., y_{{\nu} - 2}, 0)$ and $f(y_0, y_1, ..., y_{{\nu} - 2}, 1)$. - pub fn project_least_significant_variable(mut self) -> (Self, Self) { - let odds: Vec = self - .evaluations - .iter() - .enumerate() - .filter_map(|(idx, x)| if idx % 2 == 1 { Some(*x) } else { None }) - .collect(); - - // Builds the evens multilinear from the current `self.evaluations` buffer, which saves an - // allocation. - let evens = { - let evens_size = self.num_evaluations() / 2; - for write_idx in 0..evens_size { - let read_idx = write_idx * 2; - self.evaluations[write_idx] = self.evaluations[read_idx]; - } - self.evaluations.truncate(evens_size); - - self.evaluations - }; - - (Self::from_evaluations(evens), Self::from_evaluations(odds)) + self.evaluations[i] = + evaluations_i + round_challenge * (evaluations_i_plus_num_evals - evaluations_i); + } + self.evaluations.truncate(num_evals); } } diff --git a/sumcheck/src/prover/plain.rs b/sumcheck/src/prover/plain.rs index 305e3e234..617736e79 100644 --- a/sumcheck/src/prover/plain.rs +++ b/sumcheck/src/prover/plain.rs @@ -282,11 +282,20 @@ pub fn sumcheck_prove_plain_batched Date: Tue, 1 Oct 2024 12:44:37 +0200 Subject: [PATCH 42/44] debug --- sumcheck/src/prover/plain.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sumcheck/src/prover/plain.rs b/sumcheck/src/prover/plain.rs index 617736e79..af01cc160 100644 --- a/sumcheck/src/prover/plain.rs +++ b/sumcheck/src/prover/plain.rs @@ -6,6 +6,7 @@ use alloc::vec::Vec; use crypto::{ElementHasher, RandomCoin}; +use libc_print::libc_println; use math::FieldElement; #[cfg(feature = "concurrent")] pub use rayon::prelude::*; @@ -294,6 +295,7 @@ pub fn sumcheck_prove_plain_batched Date: Tue, 1 Oct 2024 12:55:17 +0200 Subject: [PATCH 43/44] debug --- sumcheck/src/prover/plain.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/sumcheck/src/prover/plain.rs b/sumcheck/src/prover/plain.rs index af01cc160..628708f90 100644 --- a/sumcheck/src/prover/plain.rs +++ b/sumcheck/src/prover/plain.rs @@ -295,7 +295,6 @@ pub fn sumcheck_prove_plain_batched Date: Tue, 1 Oct 2024 14:02:24 +0200 Subject: [PATCH 44/44] debug --- sumcheck/src/prover/plain.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/sumcheck/src/prover/plain.rs b/sumcheck/src/prover/plain.rs index 628708f90..617736e79 100644 --- a/sumcheck/src/prover/plain.rs +++ b/sumcheck/src/prover/plain.rs @@ -6,7 +6,6 @@ use alloc::vec::Vec; use crypto::{ElementHasher, RandomCoin}; -use libc_print::libc_println; use math::FieldElement; #[cfg(feature = "concurrent")] pub use rayon::prelude::*;