From d507eb99b1495791d9c832ab2681007316226d0c Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Wed, 21 Aug 2024 08:16:40 +0200 Subject: [PATCH 01/19] add `winter-sumcheck` crate with the Sumcheck IOP for LogUp-GKR (#295) --- Cargo.toml | 4 +- air/src/air/logup_gkr.rs | 89 ++++ air/src/air/mod.rs | 3 + air/src/lib.rs | 4 +- sumcheck/Cargo.toml | 49 +++ sumcheck/README.md | 24 ++ sumcheck/benches/bind_variable.rs | 90 ++++ sumcheck/benches/eq_function.rs | 96 +++++ sumcheck/benches/sum_check_high_degree.rs | 160 +++++++ sumcheck/benches/sum_check_plain.rs | 66 +++ sumcheck/src/lib.rs | 280 +++++++++++++ sumcheck/src/multilinear.rs | 381 +++++++++++++++++ sumcheck/src/prover/error.rs | 15 + sumcheck/src/prover/high_degree.rs | 481 ++++++++++++++++++++++ sumcheck/src/prover/mod.rs | 13 + sumcheck/src/prover/plain.rs | 216 ++++++++++ sumcheck/src/univariate.rs | 295 +++++++++++++ sumcheck/src/verifier/mod.rs | 149 +++++++ 18 files changed, 2411 insertions(+), 4 deletions(-) create mode 100644 air/src/air/logup_gkr.rs create mode 100644 sumcheck/Cargo.toml create mode 100644 sumcheck/README.md create mode 100644 sumcheck/benches/bind_variable.rs create mode 100644 sumcheck/benches/eq_function.rs create mode 100644 sumcheck/benches/sum_check_high_degree.rs create mode 100644 sumcheck/benches/sum_check_plain.rs create mode 100644 sumcheck/src/lib.rs create mode 100644 sumcheck/src/multilinear.rs create mode 100644 sumcheck/src/prover/error.rs create mode 100644 sumcheck/src/prover/high_degree.rs create mode 100644 sumcheck/src/prover/mod.rs create mode 100644 sumcheck/src/prover/plain.rs create mode 100644 sumcheck/src/univariate.rs create mode 100644 sumcheck/src/verifier/mod.rs diff --git a/Cargo.toml b/Cargo.toml index b0ed3f07c..7a5f28ff7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,8 +10,8 @@ members = [ "prover", "verifier", "winterfell", - "examples" -] + "examples", + "sumcheck",] resolver = "2" [profile.release] diff --git a/air/src/air/logup_gkr.rs b/air/src/air/logup_gkr.rs new file mode 100644 index 000000000..98054c938 --- /dev/null +++ b/air/src/air/logup_gkr.rs @@ -0,0 +1,89 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use alloc::vec::Vec; + +use math::{ExtensionOf, FieldElement, StarkField, ToElements}; + +use super::EvaluationFrame; + +/// A trait containing the necessary information in order to run the LogUp-GKR protocol of [1]. +/// +/// The trait contains useful information for running the GKR protocol as well as for implementing +/// the univariate IOP for multi-linear evaluation of Section 5 in [1] for the final evaluation +/// check resulting from GKR. +/// +/// [1]: https://eprint.iacr.org/2023/1284 +pub trait LogUpGkrEvaluator: Clone + Sync { + /// Defines the base field of the evaluator. + type BaseField: StarkField; + + /// Public inputs need to compute the final claim. + type PublicInputs: ToElements + Send; + + /// Gets a list of all oracles involved in LogUp-GKR; this is intended to be used in construction of + /// MLEs. + fn get_oracles(&self) -> Vec>; + + /// Returns the number of random values needed to evaluate a query. + fn get_num_rand_values(&self) -> usize; + + /// Returns the number of fractions in the LogUp-GKR statement. + fn get_num_fractions(&self) -> usize; + + /// Returns the maximal degree of the multi-variate associated to the input layer. + /// + /// This is equal to the max of $1 + deg_k(\text{numerator}_i) * deg_k(\text{denominator}_j)$ where + /// $i$ and $j$ range over the number of numerators and denominators, respectively, and $deg_k$ + /// is the degree of a multi-variate polynomial in its $k$-th variable. + fn max_degree(&self) -> usize; + + /// Builds a query from the provided main trace frame and periodic values. + /// + /// Note: it should be possible to provide an implementation of this method based on the + /// information returned from `get_oracles()`. However, this implementation is likely to be + /// expensive compared to the hand-written implementation. However, we could provide a test + /// which verifies that `get_oracles()` and `build_query()` methods are consistent. + fn build_query(&self, frame: &EvaluationFrame, periodic_values: &[E], query: &mut [E]) + where + E: FieldElement; + + /// Evaluates the provided query and writes the results into the numerators and denominators. + /// + /// Note: it is also possible to combine `build_query()` and `evaluate_query()` into a single + /// method to avoid the need to first build the query struct and then evaluate it. However: + /// - We assume that the compiler will be able to optimize this away. + /// - Merging the methods will make it more difficult avoid inconsistencies between + /// `evaluate_query()` and `get_oracles()` methods. + fn evaluate_query( + &self, + query: &[F], + logup_randomness: &[E], + numerators: &mut [E], + denominators: &mut [E], + ) where + F: FieldElement, + E: FieldElement + ExtensionOf; + + /// Computes the final claim for the LogUp-GKR circuit. + /// + /// The default implementation of this method returns E::ZERO as it is expected that the + /// fractional sums will cancel out. However, in cases when some boundary conditions need to + /// be imposed on the LogUp-GKR relations, this method can be overridden to compute the final + /// expected claim. + fn compute_claim(&self, _inputs: &Self::PublicInputs, _rand_values: &[E]) -> E + where + E: FieldElement, + { + E::ZERO + } +} + +#[derive(Clone, Debug, PartialEq, PartialOrd, Eq, Ord)] +pub enum LogUpGkrOracle { + CurrentRow(usize), + NextRow(usize), + PeriodicValue(Vec), +} diff --git a/air/src/air/mod.rs b/air/src/air/mod.rs index 53a59fa5a..07f38cce1 100644 --- a/air/src/air/mod.rs +++ b/air/src/air/mod.rs @@ -34,6 +34,9 @@ pub use lagrange::{ LagrangeKernelRandElements, LagrangeKernelTransitionConstraints, }; +mod logup_gkr; +pub use logup_gkr::{LogUpGkrEvaluator, LogUpGkrOracle}; + mod coefficients; pub use coefficients::{ ConstraintCompositionCoefficients, DeepCompositionCoefficients, diff --git a/air/src/lib.rs b/air/src/lib.rs index 539a812d9..aaede0bda 100644 --- a/air/src/lib.rs +++ b/air/src/lib.rs @@ -47,6 +47,6 @@ pub use air::{ DeepCompositionCoefficients, EvaluationFrame, GkrRandElements, GkrVerifier, LagrangeConstraintsCompositionCoefficients, LagrangeKernelBoundaryConstraint, LagrangeKernelConstraints, LagrangeKernelEvaluationFrame, LagrangeKernelRandElements, - LagrangeKernelTransitionConstraints, TraceInfo, TransitionConstraintDegree, - TransitionConstraints, + LagrangeKernelTransitionConstraints, LogUpGkrEvaluator, LogUpGkrOracle, TraceInfo, + TransitionConstraintDegree, TransitionConstraints, }; diff --git a/sumcheck/Cargo.toml b/sumcheck/Cargo.toml new file mode 100644 index 000000000..97dc79f3a --- /dev/null +++ b/sumcheck/Cargo.toml @@ -0,0 +1,49 @@ +[package] +name = "winter-sumcheck" +version = "0.1.0" +description = "Implementation of the sum-check protocol for the LogUp-GKR protocol" +authors = ["winterfell contributors"] +readme = "README.md" +license = "MIT" +repository = "https://github.com/novifinancial/winterfell" +documentation = "https://docs.rs/winter-sumcheck/0.1.0" +categories = ["cryptography", "no-std"] +keywords = ["crypto", "sumcheck", "iop"] +edition = "2021" +rust-version = "1.78" + +[[bench]] +name = "sum_check_plain" +harness = false + +[[bench]] +name = "sum_check_high_degree" +harness = false + +[[bench]] +name = "eq_function" +harness = false +required-features = ["concurrent"] + +[[bench]] +name = "bind_variable" +harness = false +required-features = ["concurrent"] + +[features] +concurrent = ["utils/concurrent", "dep:rayon", "std"] +default = ["std"] +std = ["utils/std"] + +[dependencies] +air = { version = "0.9", path = "../air", package = "winter-air", default-features = false } +crypto = { version = "0.9", path = "../crypto", package = "winter-crypto", default-features = false } +math = { version = "0.9", path = "../math", package = "winter-math", default-features = false } +utils = { version = "0.9", path = "../utils/core", package = "winter-utils", default-features = false } +rayon = { version = "1.8", optional = true } +smallvec = { version = "1.13", default-features = false } +thiserror = { version = "1.0", git = "https://github.com/bitwalker/thiserror", branch = "no-std", default-features = false } + +[dev-dependencies] +criterion = "0.5" +rand-utils = { version = "0.9", path = "../utils/rand", package = "winter-rand-utils" } \ No newline at end of file diff --git a/sumcheck/README.md b/sumcheck/README.md new file mode 100644 index 000000000..be6734aae --- /dev/null +++ b/sumcheck/README.md @@ -0,0 +1,24 @@ +# Winter sum-check +This crate contains an implementation of the sum-check protocol intended to be used for [LogUp-GKR](https://eprint.iacr.org/2023/1284) by the Winterfell STARK prover and verifier. + +The crate provides two implementations of the sum-check protocol: + +* An implementation for the sum-check protocol as used in [LogUp-GKR](https://eprint.iacr.org/2023/1284). +* An implementation which generalizes the previous one to the case where the numerators and denominators appearing in the fractional sum-checks in Section 3 of [LogUp-GKR](https://eprint.iacr.org/2023/1284) can be non-linear compositions of multi-linear polynomials. + +The first implementation is intended to be used by the GKR protocol for proving the correct evaluation of all of the layers of the fractionl sum circuit except for the input layer. The second implementation is intended to be used for proving the correct evaluation of the input layer. + + +## Crate features +This crate can be compiled with the following features: + +* `std` - enabled by default and relies on the Rust standard library. +* `concurrent` - implies `std` and also re-exports `rayon` crate and enables multi-threaded execution for some of the crate functions. +* `no_std` - does not rely on Rust's standard library and enables compilation to WebAssembly. + +To compile with `no_std`, disable default features via `--no-default-features` flag. + +License +------- + +This project is [MIT licensed](../LICENSE). \ No newline at end of file diff --git a/sumcheck/benches/bind_variable.rs b/sumcheck/benches/bind_variable.rs new file mode 100644 index 000000000..f7e82126c --- /dev/null +++ b/sumcheck/benches/bind_variable.rs @@ -0,0 +1,90 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use std::time::Duration; + +use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion}; +use math::{fields::f64::BaseElement, FieldElement}; +use rand_utils::{rand_value, rand_vector}; +#[cfg(feature = "concurrent")] +pub use rayon::prelude::*; + +const POLY_SIZE: [usize; 2] = [1 << 18, 1 << 20]; + +fn bind_variable_serial(c: &mut Criterion) { + let mut group = c.benchmark_group("Bind variable evaluations"); + group.sample_size(10); + group.measurement_time(Duration::from_secs(10)); + + for &poly_size in POLY_SIZE.iter() { + group.bench_function(BenchmarkId::new("serial", poly_size), |b| { + b.iter_batched( + || { + let random_challenge: BaseElement = rand_value(); + let poly_evals: Vec = rand_vector(poly_size); + (random_challenge, poly_evals) + }, + |(random_challenge, poly_evals)| { + let mut poly_evals = poly_evals; + bind_least_significant_variable_serial(&mut poly_evals, random_challenge) + }, + BatchSize::SmallInput, + ) + }); + } +} + +fn bind_variable_parallel(c: &mut Criterion) { + let mut group = c.benchmark_group("Bind variable function evaluations"); + group.sample_size(10); + group.measurement_time(Duration::from_secs(10)); + + for &poly_size in POLY_SIZE.iter() { + group.bench_function(BenchmarkId::new("parallel", poly_size), |b| { + b.iter_batched( + || { + let random_challenge: BaseElement = rand_value(); + let poly_evals: Vec = rand_vector(poly_size); + (random_challenge, poly_evals) + }, + |(random_challenge, poly_evals)| { + let mut poly_evals = poly_evals; + bind_least_significant_variable_parallel(&mut poly_evals, random_challenge) + }, + BatchSize::SmallInput, + ) + }); + } +} + +fn bind_least_significant_variable_serial( + evaluations: &mut Vec, + round_challenge: E, +) { + let num_evals = evaluations.len() >> 1; + + for i in 0..num_evals { + evaluations[i] = evaluations[i << 1] + + round_challenge * (evaluations[(i << 1) + 1] - evaluations[i << 1]); + } + evaluations.truncate(num_evals); +} + +fn bind_least_significant_variable_parallel( + evaluations: &mut Vec, + round_challenge: E, +) { + let num_evals = evaluations.len() >> 1; + + let mut result = unsafe { utils::uninit_vector(num_evals) }; + result.par_iter_mut().enumerate().for_each(|(i, ev)| { + *ev = evaluations[i << 1] + + round_challenge * (evaluations[(i << 1) + 1] - evaluations[i << 1]) + }); + *evaluations = result +} + +criterion_group!(group, bind_variable_serial, bind_variable_parallel); +criterion_main!(group); diff --git a/sumcheck/benches/eq_function.rs b/sumcheck/benches/eq_function.rs new file mode 100644 index 000000000..86e6cad98 --- /dev/null +++ b/sumcheck/benches/eq_function.rs @@ -0,0 +1,96 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use std::time::Duration; + +use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion}; +use math::{fields::f64::BaseElement, FieldElement}; +use rand_utils::rand_vector; +#[cfg(feature = "concurrent")] +pub use rayon::prelude::*; + +const LOG_POLY_SIZE: [usize; 2] = [18, 20]; + +fn evaluate_eq_serial(c: &mut Criterion) { + let mut group = c.benchmark_group("EQ function evaluations"); + group.sample_size(10); + group.measurement_time(Duration::from_secs(10)); + + for &log_poly_size in LOG_POLY_SIZE.iter() { + group.bench_function(BenchmarkId::new("serial", log_poly_size), |b| { + b.iter_batched( + || { + let randomness: Vec = rand_vector(log_poly_size); + randomness + }, + |rand| eq_evaluations(&rand), + BatchSize::SmallInput, + ) + }); + } +} + +fn evaluate_eq_parallel(c: &mut Criterion) { + let mut group = c.benchmark_group("EQ function evaluations"); + group.sample_size(10); + group.measurement_time(Duration::from_secs(10)); + + for &log_poly_size in LOG_POLY_SIZE.iter() { + group.bench_function(BenchmarkId::new("parallel", log_poly_size), |b| { + b.iter_batched( + || { + let randomness: Vec = rand_vector(log_poly_size); + randomness + }, + |rand| eq_evaluations_par(&rand), + BatchSize::SmallInput, + ) + }); + } +} + +fn eq_evaluations(query: &[E]) -> Vec { + let n = 1 << query.len(); + let mut evals = unsafe { utils::uninit_vector(n) }; + + let mut size = 1; + evals[0] = E::ONE; + for r_i in query.iter() { + let (left_evals, right_evals) = evals.split_at_mut(size); + left_evals.iter_mut().zip(right_evals.iter_mut()).for_each(|(left, right)| { + let factor = *left; + *right = factor * *r_i; + *left -= *right; + }); + + size *= 2; + } + evals +} + +fn eq_evaluations_par(query: &[E]) -> Vec { + let n = 1 << query.len(); + let mut evals = unsafe { utils::uninit_vector(n) }; + + let mut size = 1; + evals[0] = E::ONE; + for r_i in query.iter() { + let (left_evals, right_evals) = evals.split_at_mut(size); + left_evals + .par_iter_mut() + .zip(right_evals.par_iter_mut()) + .for_each(|(left, right)| { + let factor = *left; + *right = factor * *r_i; + *left -= *right; + }); + + size <<= 1; + } + evals +} + +criterion_group!(group, evaluate_eq_serial, evaluate_eq_parallel); +criterion_main!(group); diff --git a/sumcheck/benches/sum_check_high_degree.rs b/sumcheck/benches/sum_check_high_degree.rs new file mode 100644 index 000000000..3db6a37e3 --- /dev/null +++ b/sumcheck/benches/sum_check_high_degree.rs @@ -0,0 +1,160 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use std::{marker::PhantomData, time::Duration}; + +use air::{EvaluationFrame, LogUpGkrEvaluator, LogUpGkrOracle}; +use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion}; +use crypto::{hashers::Blake3_192, DefaultRandomCoin, RandomCoin}; +use math::{fields::f64::BaseElement, ExtensionOf, FieldElement}; +use rand_utils::{rand_value, rand_vector}; +#[cfg(feature = "concurrent")] +pub use rayon::prelude::*; +use winter_sumcheck::{sum_check_prove_higher_degree, MultiLinearPoly}; + +const LOG_POLY_SIZE: [usize; 2] = [18, 20]; + +fn sum_check_high_degree(c: &mut Criterion) { + let mut group = c.benchmark_group("Sum-check prover high degree"); + group.sample_size(10); + group.measurement_time(Duration::from_secs(10)); + + for &log_poly_size in LOG_POLY_SIZE.iter() { + group.bench_function(BenchmarkId::new("", log_poly_size), |b| { + b.iter_batched( + || { + let logup_randomness = rand_vector(1); + let evaluator = PlainLogUpGkrEval::::default(); + let transcript = + DefaultRandomCoin::>::new(&[BaseElement::ZERO; 4]); + ( + setup_sum_check::(log_poly_size), + evaluator, + logup_randomness, + transcript, + ) + }, + |( + (claim, r_batch, rand_pt, (ml0, ml1, ml2, ml3, ml4)), + evaluator, + logup_randomness, + transcript, + )| { + let mls = vec![ml0, ml1, ml2, ml3, ml4]; + let mut transcript = transcript; + + sum_check_prove_higher_degree( + &evaluator, + rand_pt, + claim, + r_batch, + logup_randomness, + mls, + &mut transcript, + ) + }, + BatchSize::SmallInput, + ) + }); + } +} + +#[allow(clippy::too_many_arguments)] +#[allow(clippy::type_complexity)] +fn setup_sum_check( + log_size: usize, +) -> ( + E, + E, + Vec, + ( + MultiLinearPoly, + MultiLinearPoly, + MultiLinearPoly, + MultiLinearPoly, + MultiLinearPoly, + ), +) { + let n = 1 << log_size; + let table = MultiLinearPoly::from_evaluations(rand_vector(n)); + let multiplicity = MultiLinearPoly::from_evaluations(rand_vector(n)); + let values_0 = MultiLinearPoly::from_evaluations(rand_vector(n)); + let values_1 = MultiLinearPoly::from_evaluations(rand_vector(n)); + let values_2 = MultiLinearPoly::from_evaluations(rand_vector(n)); + + // this will not generate the correct claim with overwhelming probability but should be fine + // for benchmarking + let rand_pt: Vec = rand_vector(log_size + 2); + let r_batch: E = rand_value(); + let claim: E = rand_value(); + + (claim, r_batch, rand_pt, (table, multiplicity, values_0, values_1, values_2)) +} + +#[derive(Clone, Default)] +pub struct PlainLogUpGkrEval { + _field: PhantomData, +} + +impl LogUpGkrEvaluator for PlainLogUpGkrEval { + type BaseField = BaseElement; + + type PublicInputs = (); + + fn get_oracles(&self) -> Vec> { + let committed_0 = LogUpGkrOracle::CurrentRow(0); + let committed_1 = LogUpGkrOracle::CurrentRow(1); + let committed_2 = LogUpGkrOracle::CurrentRow(2); + let committed_3 = LogUpGkrOracle::CurrentRow(3); + let committed_4 = LogUpGkrOracle::CurrentRow(4); + vec![committed_0, committed_1, committed_2, committed_3, committed_4] + } + + fn get_num_rand_values(&self) -> usize { + 1 + } + + fn get_num_fractions(&self) -> usize { + 4 + } + + fn max_degree(&self) -> usize { + 3 + } + + fn build_query(&self, frame: &EvaluationFrame, _periodic_values: &[E], query: &mut [E]) + where + E: FieldElement, + { + query.iter_mut().zip(frame.current().iter()).for_each(|(q, f)| *q = *f); + } + + fn evaluate_query( + &self, + query: &[F], + rand_values: &[E], + numerator: &mut [E], + denominator: &mut [E], + ) where + F: FieldElement, + E: FieldElement + ExtensionOf, + { + assert_eq!(numerator.len(), 4); + assert_eq!(denominator.len(), 4); + assert_eq!(query.len(), 5); + numerator[0] = E::from(query[1]); + numerator[1] = E::ONE; + numerator[2] = E::ONE; + numerator[3] = E::ONE; + + denominator[0] = rand_values[0] - E::from(query[0]); + denominator[1] = -(rand_values[0] - E::from(query[2])); + denominator[2] = -(rand_values[0] - E::from(query[3])); + denominator[3] = -(rand_values[0] - E::from(query[4])); + } +} + +criterion_group!(group, sum_check_high_degree); +criterion_main!(group); diff --git a/sumcheck/benches/sum_check_plain.rs b/sumcheck/benches/sum_check_plain.rs new file mode 100644 index 000000000..14fd859ce --- /dev/null +++ b/sumcheck/benches/sum_check_plain.rs @@ -0,0 +1,66 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use std::time::Duration; + +use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion}; +use crypto::{hashers::Blake3_192, DefaultRandomCoin, RandomCoin}; +use math::{fields::f64::BaseElement, FieldElement}; +use rand_utils::{rand_value, rand_vector}; +#[cfg(feature = "concurrent")] +pub use rayon::prelude::*; +use winter_sumcheck::{sumcheck_prove_plain, EqFunction, MultiLinearPoly}; + +const LOG_POLY_SIZE: [usize; 2] = [18, 20]; + +fn sum_check_plain(c: &mut Criterion) { + let mut group = c.benchmark_group("Sum-check prover plain"); + group.sample_size(10); + group.measurement_time(Duration::from_secs(10)); + + for &log_poly_size in LOG_POLY_SIZE.iter() { + group.bench_function(BenchmarkId::new("", log_poly_size), |b| { + b.iter_batched( + || { + let transcript = + DefaultRandomCoin::>::new(&[BaseElement::ZERO; 4]); + (setup_sum_check::(log_poly_size), transcript) + }, + |((claim, r_batch, p, q, eq), transcript)| { + let mut eq = eq; + let mut transcript = transcript; + + sumcheck_prove_plain(claim, r_batch, p, q, &mut eq, &mut transcript) + }, + BatchSize::SmallInput, + ) + }); + } +} + +#[allow(clippy::too_many_arguments)] +#[allow(clippy::type_complexity)] +fn setup_sum_check( + log_size: usize, +) -> (E, E, MultiLinearPoly, MultiLinearPoly, MultiLinearPoly) { + let n = 1 << (log_size + 1); + let p: Vec = rand_vector(n); + let q: Vec = rand_vector(n); + + // this will not generate the correct claim with overwhelming probability but should be fine + // for benchmarking + let rand_pt = rand_vector(log_size); + let r_batch: E = rand_value(); + let claim: E = rand_value(); + + let p = MultiLinearPoly::from_evaluations(p); + let q = MultiLinearPoly::from_evaluations(q); + let eq = MultiLinearPoly::from_evaluations(EqFunction::new(rand_pt.into()).evaluations()); + + (claim, r_batch, p, q, eq) +} + +criterion_group!(group, sum_check_plain); +criterion_main!(group); diff --git a/sumcheck/src/lib.rs b/sumcheck/src/lib.rs new file mode 100644 index 000000000..f30db974c --- /dev/null +++ b/sumcheck/src/lib.rs @@ -0,0 +1,280 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#![no_std] + +use alloc::vec::Vec; + +use ::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; +use math::FieldElement; + +#[macro_use] +extern crate alloc; + +#[cfg(feature = "concurrent")] +pub use rayon::prelude::*; + +mod prover; +pub use prover::*; + +mod verifier; +pub use verifier::*; + +mod univariate; +pub use univariate::{CompressedUnivariatePoly, CompressedUnivariatePolyEvals}; + +mod multilinear; +pub use multilinear::{EqFunction, MultiLinearPoly}; + +/// Represents an opening claim at an evaluation point against a batch of oracles. +/// +/// After verifying [`Proof`], the verifier is left with a question on the validity of a final +/// claim on a number of oracles open to a given set of values at some given point. +/// This question is answered either using further interaction with the Prover or using +/// a polynomial commitment opening proof in the compiled protocol. +#[derive(Clone, Debug)] +pub struct FinalOpeningClaim { + pub eval_point: Vec, + pub openings: Vec, +} + +impl Serializable for FinalOpeningClaim { + fn write_into(&self, target: &mut W) { + let Self { eval_point, openings } = self; + eval_point.write_into(target); + openings.write_into(target); + } +} + +impl Deserializable for FinalOpeningClaim +where + E: FieldElement, +{ + fn read_from(source: &mut R) -> Result { + Ok(Self { + eval_point: Deserializable::read_from(source)?, + openings: Deserializable::read_from(source)?, + }) + } +} + +/// A sum-check proof. +/// +/// Composed of the round proofs i.e., the polynomials sent by the Prover at each round as well as +/// the (claimed) openings of the multi-linear oracles at the evaluation point given by the round +/// challenges. +#[derive(Debug, Clone)] +pub struct SumCheckProof { + pub openings_claim: FinalOpeningClaim, + pub round_proofs: Vec>, +} + +impl Serializable for SumCheckProof +where + E: FieldElement, +{ + fn write_into(&self, target: &mut W) { + self.openings_claim.write_into(target); + self.round_proofs.write_into(target); + } +} + +impl Deserializable for SumCheckProof +where + E: FieldElement, +{ + fn read_from(source: &mut R) -> Result { + Ok(Self { + openings_claim: Deserializable::read_from(source)?, + round_proofs: Deserializable::read_from(source)?, + }) + } +} + +/// A sum-check round proof. +/// +/// This represents the partial polynomial sent by the Prover during one of the rounds of the +/// sum-check protocol. The polynomial is in coefficient form and excludes the coefficient for +/// the linear term as the Verifier can recover it from the other coefficients and the current +/// (reduced) claim. +#[derive(Debug, Clone)] +pub struct RoundProof { + pub round_poly_coefs: CompressedUnivariatePoly, +} + +impl Serializable for RoundProof { + fn write_into(&self, target: &mut W) { + let Self { round_poly_coefs } = self; + round_poly_coefs.write_into(target); + } +} + +impl Deserializable for RoundProof +where + E: FieldElement, +{ + fn read_from(source: &mut R) -> Result { + Ok(Self { + round_poly_coefs: Deserializable::read_from(source)?, + }) + } +} + +/// A proof for the input circuit layer i.e., the final layer in the GKR protocol. +#[derive(Debug, Clone)] +pub struct FinalLayerProof { + pub proof: SumCheckProof, +} + +impl Serializable for FinalLayerProof +where + E: FieldElement, +{ + fn write_into(&self, target: &mut W) { + let Self { proof } = self; + proof.write_into(target); + } +} + +impl Deserializable for FinalLayerProof +where + E: FieldElement, +{ + fn read_from(source: &mut R) -> Result { + Ok(Self { + proof: Deserializable::read_from(source)?, + }) + } +} + +/// Contains the round challenges sent by the Verifier up to some round as well as the current +/// reduced claim. +#[derive(Debug)] +pub struct SumCheckRoundClaim { + pub eval_point: Vec, + pub claim: E, +} + +// GKR CIRCUIT PROOF +// =============================================================================================== + +/// A GKR proof for the correct evaluation of the sum of fractions circuit. +#[derive(Debug, Clone)] +pub struct GkrCircuitProof { + pub circuit_outputs: CircuitOutput, + pub before_final_layer_proofs: BeforeFinalLayerProof, + pub final_layer_proof: FinalLayerProof, +} + +impl GkrCircuitProof { + pub fn get_final_opening_claim(&self) -> FinalOpeningClaim { + self.final_layer_proof.proof.openings_claim.clone() + } +} + +impl Serializable for GkrCircuitProof +where + E: FieldElement, +{ + fn write_into(&self, target: &mut W) { + self.circuit_outputs.write_into(target); + self.before_final_layer_proofs.write_into(target); + self.final_layer_proof.proof.write_into(target); + } +} + +impl Deserializable for GkrCircuitProof +where + E: FieldElement, +{ + fn read_from(source: &mut R) -> Result { + Ok(Self { + circuit_outputs: CircuitOutput::read_from(source)?, + before_final_layer_proofs: BeforeFinalLayerProof::read_from(source)?, + final_layer_proof: FinalLayerProof::read_from(source)?, + }) + } +} + +/// A set of sum-check proofs for all GKR layers but for the input circuit layer. +#[derive(Debug, Clone)] +pub struct BeforeFinalLayerProof { + pub proof: Vec>, +} + +impl Serializable for BeforeFinalLayerProof +where + E: FieldElement, +{ + fn write_into(&self, target: &mut W) { + let Self { proof } = self; + proof.write_into(target); + } +} + +impl Deserializable for BeforeFinalLayerProof +where + E: FieldElement, +{ + fn read_from(source: &mut R) -> Result { + Ok(Self { + proof: Deserializable::read_from(source)?, + }) + } +} + +/// Holds the output layer of an [`EvaluatedCircuit`]. +#[derive(Clone, Debug)] +pub struct CircuitOutput { + pub numerators: MultiLinearPoly, + pub denominators: MultiLinearPoly, +} + +impl Serializable for CircuitOutput +where + E: FieldElement, +{ + fn write_into(&self, target: &mut W) { + let Self { numerators, denominators } = self; + numerators.write_into(target); + denominators.write_into(target); + } +} + +impl Deserializable for CircuitOutput +where + E: FieldElement, +{ + fn read_from(source: &mut R) -> Result { + Ok(Self { + numerators: MultiLinearPoly::read_from(source)?, + denominators: MultiLinearPoly::read_from(source)?, + }) + } +} + +/// The non-linear composition polynomial of the LogUp-GKR protocol. +/// +/// This is the result of batching the `p_k` and `q_k` of section 3.2 in +/// https://eprint.iacr.org/2023/1284.pdf. +#[inline(always)] +fn comb_func(p0: E, p1: E, q0: E, q1: E, eq: E, r_batch: E) -> E { + (p0 * q1 + p1 * q0 + r_batch * q0 * q1) * eq +} + +/// The non-linear composition polynomial of the LogUp-GKR protocol specific to the input layer. +pub fn evaluate_composition_poly( + eq_at_mu: &[E], + numerators: &[E], + denominators: &[E], + eq_eval: E, + r_sum_check: E, +) -> E { + numerators + .chunks(2) + .zip(denominators.chunks(2).zip(eq_at_mu.iter())) + .map(|(p, (q, eq_w))| *eq_w * comb_func(p[0], p[1], q[0], q[1], eq_eval, r_sum_check)) + .fold(E::ZERO, |acc, x| acc + x) +} diff --git a/sumcheck/src/multilinear.rs b/sumcheck/src/multilinear.rs new file mode 100644 index 000000000..110ef1fa7 --- /dev/null +++ b/sumcheck/src/multilinear.rs @@ -0,0 +1,381 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use alloc::vec::Vec; +use core::ops::Index; + +use math::FieldElement; +#[cfg(feature = "concurrent")] +pub use rayon::prelude::*; +use smallvec::SmallVec; +use utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; + +// MULTI-LINEAR POLYNOMIAL +// ================================================================================================ + +/// Represents a multi-linear polynomial. +/// +/// The representation stores the evaluations of the polynomial over the boolean hyper-cube +/// ${0 , 1}^{\nu}$. +#[derive(Clone, Debug, PartialEq)] +pub struct MultiLinearPoly { + evaluations: Vec, +} + +impl MultiLinearPoly { + /// Constructs a [`MultiLinearPoly`] from its evaluations over the boolean hyper-cube ${0 , 1}^{\nu}$. + pub fn from_evaluations(evaluations: Vec) -> Self { + assert!(evaluations.len().is_power_of_two(), "A multi-linear polynomial should have a power of 2 number of evaluations over the Boolean hyper-cube"); + Self { evaluations } + } + + /// Returns the number of variables of the multi-linear polynomial. + pub fn num_variables(&self) -> usize { + self.evaluations.len().trailing_zeros() as usize + } + + /// Returns the evaluations over the boolean hyper-cube. + pub fn evaluations(&self) -> &[E] { + &self.evaluations + } + + /// Returns the number of evaluations. This is equal to the size of the boolean hyper-cube. + pub fn num_evaluations(&self) -> usize { + self.evaluations.len() + } + + /// Evaluate the multi-linear at some query $(r_0, ..., r_{{\nu} - 1}) \in \mathbb{F}^{\nu}$. + /// + /// It first computes the evaluations of the Lagrange basis polynomials over the interpolating + /// set ${0 , 1}^{\nu}$ at $(r_0, ..., r_{{\nu} - 1})$ i.e., the Lagrange kernel at $(r_0, ..., r_{{\nu} - 1})$. + /// The evaluation then is the inner product, indexed by ${0 , 1}^{\nu}$, of the vector of + /// evaluations times the Lagrange kernel. + pub fn evaluate(&self, query: &[E]) -> E { + let tensored_query = compute_lagrange_basis_evals_at(query); + inner_product(&self.evaluations, &tensored_query) + } + + /// Similar to [`Self::evaluate`], except that the query was already turned into the Lagrange + /// kernel (i.e. the [`lagrange_ker::EqFunction`] evaluated at every point in the set + /// `${0 , 1}^{\nu}$`). + /// + /// This is more efficient than [`Self::evaluate`] when multiple different [`MultiLinearPoly`] + /// need to be evaluated at the same query point. + pub fn evaluate_with_lagrange_kernel(&self, lagrange_kernel: &[E]) -> E { + inner_product(&self.evaluations, lagrange_kernel) + } + + /// Computes $f(r_0, y_1, ..., y_{{\nu} - 1})$ using the linear interpolation formula + /// $(1 - r_0) * f(0, y_1, ..., y_{{\nu} - 1}) + r_0 * f(1, y_1, ..., y_{{\nu} - 1})$ and assigns + /// the resulting multi-linear, defined over a domain of half the size, to `self`. + pub fn bind_least_significant_variable(&mut self, round_challenge: E) { + let num_evals = self.evaluations.len() >> 1; + #[cfg(not(feature = "concurrent"))] + { + for i in 0..num_evals { + self.evaluations[i] = self.evaluations[i << 1] + + round_challenge * (self.evaluations[(i << 1) + 1] - self.evaluations[i << 1]); + } + self.evaluations.truncate(num_evals) + } + + #[cfg(feature = "concurrent")] + { + let mut result = unsafe { utils::uninit_vector(num_evals) }; + result.par_iter_mut().enumerate().for_each(|(i, ev)| { + *ev = self.evaluations[i << 1] + + round_challenge * (self.evaluations[(i << 1) + 1] - self.evaluations[i << 1]) + }); + self.evaluations = result + } + } + + /// Given the multilinear polynomial $f(y_0, y_1, ..., y_{{\nu} - 1})$, returns two polynomials: + /// $f(0, y_1, ..., y_{{\nu} - 1})$ and $f(1, y_1, ..., y_{{\nu} - 1})$. + pub fn project_least_significant_variable(mut self) -> (Self, Self) { + let odds: Vec = self + .evaluations + .iter() + .enumerate() + .filter_map(|(idx, x)| if idx % 2 == 1 { Some(*x) } else { None }) + .collect(); + + // Builds the evens multilinear from the current `self.evaluations` buffer, which saves an + // allocation. + let evens = { + let evens_size = self.num_evaluations() / 2; + for write_idx in 0..evens_size { + let read_idx = write_idx * 2; + self.evaluations[write_idx] = self.evaluations[read_idx]; + } + self.evaluations.truncate(evens_size); + + self.evaluations + }; + + (Self::from_evaluations(evens), Self::from_evaluations(odds)) + } +} + +impl Index for MultiLinearPoly { + type Output = E; + + fn index(&self, index: usize) -> &E { + &(self.evaluations[index]) + } +} + +impl Serializable for MultiLinearPoly +where + E: FieldElement, +{ + fn write_into(&self, target: &mut W) { + let Self { evaluations } = self; + evaluations.write_into(target); + } +} + +impl Deserializable for MultiLinearPoly +where + E: FieldElement, +{ + fn read_from(source: &mut R) -> Result { + Ok(Self { + evaluations: Deserializable::read_from(source)?, + }) + } +} + +// EQ FUNCTION +// ================================================================================================ + +/// Maximal expected size of the point of a given Lagrange kernel. +const MAX_EQ_SIZE: usize = 25; + +/// The EQ (equality) function is the binary function defined by +/// +/// $$ +/// EQ: {0 , 1}^{\nu} ⛌ {0 , 1}^{\nu} \longrightarrow {0 , 1} +/// ((x_0, ..., x_{{\nu} - 1}), (y_0, ..., y_{{\nu} - 1})) \mapsto \prod_{i = 0}^{{\nu} - 1} (x_i \cdot y_i + (1 - x_i) +/// \cdot (1 - y_i)) +/// $$ +/// +/// Taking its multi-linear extension $\tilde{EQ}$, we can define a basis for the set of multi-linear +/// polynomials in {\nu} variables by +/// $${\tilde{EQ}(., (y_0, ..., y_{{\nu} - 1})): (y_0, ..., y_{{\nu} - 1}) \in {0 , 1}^{\nu}}$$ +/// where each basis function is a function of its first argument. This is called the Lagrange or +/// evaluation basis for evaluation set ${0 , 1}^{\nu}$. +/// +/// Given a function $(f: {0 , 1}^{\nu} \longrightarrow \mathbb{F})$, its multi-linear extension (i.e., the unique +/// mult-linear polynomial extending `f` to $(\tilde{f}: \mathbb{F}^{\nu} \longrightarrow \mathbb{F})$ and agreeing with it on ${0 , 1}^{\nu}$) is +/// defined as the summation of the evaluations of f against the Lagrange basis. +/// More specifically, given $(r_0, ..., r_{{\nu} - 1}) \in \mathbb{F}^{\nu}$, then: +/// +/// $$ +/// \tilde{f}(r_0, ..., r_{{\nu} - 1}) = \sum_{(y_0, ..., y_{{\nu} - 1}) \in {0 , 1}^{\nu}} +/// f(y_0, ..., y_{{\nu} - 1}) \tilde{EQ}((r_0, ..., r_{{\nu} - 1}), (y_0, ..., y_{{\nu} - 1})) +/// $$ +/// +/// We call the Lagrange kernel the evaluation of the $\tilde{EQ}$ function at +/// $((r_0, ..., r_{{\nu} - 1}), (y_0, ..., y_{{\nu} - 1}))$ for all $(y_0, ..., y_{{\nu} - 1}) \in {0 , 1}^{\nu}$ for +/// a fixed $(r_0, ..., r_{{\nu} - 1}) \in \mathbb{F}^{\nu}$. +/// +/// [`EqFunction`] represents $\tilde{EQ}$ the multi-linear extension of +/// +/// $((y_0, ..., y_{{\nu} - 1}) \mapsto EQ((r_0, ..., r_{{\nu} - 1}), (y_0, ..., y_{{\nu} - 1})))$ +/// +/// and contains a method to generate the Lagrange kernel for defining evaluations of multi-linear +/// extensions of arbitrary functions $(f: {0 , 1}^{\nu} \longrightarrow \mathbb{F})$ at a given point $(r_0, ..., r_{{\nu} - 1})$ +/// as well as a method to evaluate $\tilde{EQ}((r_0, ..., r_{{\nu} - 1}), (t_0, ..., t_{{\nu} - 1})))$ for +/// $(t_0, ..., t_{{\nu} - 1}) \in \mathbb{F}^{\nu}$. +pub struct EqFunction { + r: SmallVec<[E; MAX_EQ_SIZE]>, +} + +impl EqFunction { + /// Creates a new [EqFunction]. + pub fn new(r: SmallVec<[E; MAX_EQ_SIZE]>) -> Self { + EqFunction { r } + } + + /// Computes $\tilde{EQ}((r_0, ..., r_{{\nu} - 1}), (t_0, ..., t_{{\nu} - 1})))$. + pub fn evaluate(&self, t: &[E]) -> E { + assert_eq!(self.r.len(), t.len()); + + (0..self.r.len()) + .map(|i| self.r[i] * t[i] + (E::ONE - self.r[i]) * (E::ONE - t[i])) + .fold(E::ONE, |acc, term| acc * term) + } + + /// Computes $\tilde{EQ}((r_0, ..., r_{{\nu} - 1}), (y_0, ..., y_{{\nu} - 1}))$ for all + /// $(y_0, ..., y_{{\nu} - 1}) \in {0 , 1}^{\nu}$ i.e., the Lagrange kernel at $r = (r_0, ..., r_{{\nu} - 1})$. + pub fn evaluations(&self) -> Vec { + compute_lagrange_basis_evals_at(&self.r) + } + + /// Returns the evaluations of + /// $((y_0, ..., y_{{\nu} - 1}) \mapsto \tilde{EQ}((r_0, ..., r_{{\nu} - 1}), (y_0, ..., y_{{\nu} - 1})))$ + /// over ${0 , 1}^{\nu}$. + pub fn ml_at(evaluation_point: SmallVec<[E; MAX_EQ_SIZE]>) -> MultiLinearPoly { + let eq_evals = EqFunction::new(evaluation_point).evaluations(); + MultiLinearPoly::from_evaluations(eq_evals) + } +} + +// HELPER +// ================================================================================================ + +/// Computes the evaluations of the Lagrange basis polynomials over the interpolating +/// set ${0 , 1}^{\nu}$ at $(r_0, ..., r_{{\nu} - 1})$ i.e., the Lagrange kernel at $(r_0, ..., r_{{\nu} - 1})$. +/// +/// If `concurrent` feature is enabled, this function can make use of multi-threading. +/// +/// The implementation uses the memoization technique in Lemma 3.8 in [1]. More precisely, we can +/// build a table $A^{(\nu)}$ in $\nu$ steps using the following master equation: +/// +/// $$ +/// A^{(j)}\left[\left(w_{1}, \dots, w_{j} \right)\right] = +/// A^{(j - 1)}\left[\left(w_{1}, \dots, w_{j - 1} \right)\right] \times +/// \left(w_{j}\cdot r_{j} + (1 - w_{j})\cdot( 1 - r_{j}) \right) +/// $$ +/// +/// if we interpret $\left(w_{1}, \dots, w_{j} \right)$ in little endian i.e., +/// $\left(w_{1}, \dots, w_{j} \right) = \sum_{i=1}^{\nu} 2^{i - 1}\cdot w_{i}$. +/// +/// We thus have the following algorithm: +/// +/// 1. Split current table, stored as a vector, $A^{(j)}\left[\left(w_{1}, \dots, w_{j} \right)\right]$ +/// into two tables $A^{(j)}\left[\left(w_{1}, \dots, w_{j-1}, 0 \right)\right]$ and +/// $A^{(j)}\left[\left(w_{1}, \dots, w_{j-1}, 1 \right)\right]$, +/// with the first part initialized to $A^{(j - 1)}\left[\left(w_{1}, \dots, w_{j-1} \right)\right]$. +/// 2. Iterating over $\left(w_{1}, \dots, w_{j-1} \right)$, do: +/// 1. Let $factor = A^{(j - 1)}\left[\left(w_{1}, \dots, w_{j-1} \right)\right]$, which is equal +/// by the above to $A^{(j)}\left[\left(w_{1}, \dots, w_{j-1}, 0 \right)\right]$. +/// 2. $A^{(j)}\left[\left(w_{1}, \dots, w_{j-1}, 1 \right)\right] = factor \cdot r_j$ +/// 3. $A^{(j)}\left[\left(w_{1}, \dots, w_{j-1}, 0 \right)\right] = +/// A^{(j)}\left[\left(w_{1}, \dots, w_{j-1}, 0 \right)\right] - +/// A^{(j)}\left[\left(w_{1}, \dots, w_{j-1}, 1 \right)\right]$ +/// +/// Note that we can allocate from the start a vector of size $2^{\nu}$ in order to hold the final +/// as well as the intermediate tables. +/// +/// [1]: https://people.cs.georgetown.edu/jthaler/ProofsArgsAndZK.pdf +fn compute_lagrange_basis_evals_at(query: &[E]) -> Vec { + let n = 1 << query.len(); + let mut evals = unsafe { utils::uninit_vector(n) }; + + let mut size = 1; + evals[0] = E::ONE; + #[cfg(not(feature = "concurrent"))] + let evals = { + for r_i in query.iter() { + let (left_evals, right_evals) = evals.split_at_mut(size); + left_evals.iter_mut().zip(right_evals.iter_mut()).for_each(|(left, right)| { + let factor = *left; + *right = factor * *r_i; + *left -= *right; + }); + + size <<= 1; + } + evals + }; + + #[cfg(feature = "concurrent")] + let evals = { + for r_i in query.iter() { + let (left_evals, right_evals) = evals.split_at_mut(size); + left_evals + .par_iter_mut() + .zip(right_evals.par_iter_mut()) + .for_each(|(left, right)| { + let factor = *left; + *right = factor * *r_i; + *left -= *right; + }); + + size <<= 1; + } + evals + }; + + evals +} + +/// Computes the inner product in the extension field of two slices with the same number of items. +/// +/// If `concurrent` feature is enabled, this function can make use of multi-threading. +pub fn inner_product(x: &[E], y: &[E]) -> E { + #[cfg(not(feature = "concurrent"))] + return x.iter().zip(y.iter()).fold(E::ZERO, |acc, (x_i, y_i)| acc + *x_i * *y_i); + + #[cfg(feature = "concurrent")] + return x + .par_iter() + .zip(y.par_iter()) + .map(|(x_i, y_i)| *x_i * *y_i) + .reduce(|| E::ZERO, |a, b| a + b); +} + +// TESTS +// ================================================================================================ + +#[test] +fn multi_linear_sanity_checks() { + use math::fields::f64::BaseElement; + let nu = 3; + let n = 1 << nu; + + // the zero multi-linear should evaluate to zero + let p = MultiLinearPoly::from_evaluations(vec![BaseElement::ZERO; n]); + let challenge: Vec = rand_utils::rand_vector(nu); + + assert_eq!(BaseElement::ZERO, p.evaluate(&challenge)); + + // the constant multi-linear should be constant everywhere + let constant = rand_utils::rand_value(); + let p = MultiLinearPoly::from_evaluations(vec![constant; n]); + let challenge: Vec = rand_utils::rand_vector(nu); + + assert_eq!(constant, p.evaluate(&challenge)) +} + +#[test] +fn test_bind() { + use math::fields::f64::BaseElement; + let mut p = MultiLinearPoly::from_evaluations(vec![BaseElement::ONE; 8]); + let expected = MultiLinearPoly::from_evaluations(vec![BaseElement::ONE; 4]); + + let challenge = rand_utils::rand_value(); + p.bind_least_significant_variable(challenge); + assert_eq!(p, expected) +} + +#[test] +fn test_eq_function() { + use math::fields::f64::BaseElement; + use rand_utils::rand_value; + use smallvec::smallvec; + + let one = BaseElement::ONE; + + // Lagrange kernel is computed correctly + let r0 = rand_value(); + let r1 = rand_value(); + let eq_function = EqFunction::new(smallvec![r0, r1]); + + let expected = vec![(one - r0) * (one - r1), r0 * (one - r1), (one - r0) * r1, r0 * r1]; + + assert_eq!(expected, eq_function.evaluations()); + + // Lagrange kernel evaluation is correct + let q0 = rand_value(); + let q1 = rand_value(); + let tensored_query = vec![(one - q0) * (one - q1), q0 * (one - q1), (one - q0) * q1, q0 * q1]; + + let expected = inner_product(&tensored_query, &eq_function.evaluations()); + + assert_eq!(expected, eq_function.evaluate(&[q0, q1])) +} diff --git a/sumcheck/src/prover/error.rs b/sumcheck/src/prover/error.rs new file mode 100644 index 000000000..c86198d73 --- /dev/null +++ b/sumcheck/src/prover/error.rs @@ -0,0 +1,15 @@ +#[derive(Debug, thiserror::Error)] +pub enum SumCheckProverError { + #[error("number of rounds for sum-check must be greater than zero")] + NumRoundsZero, + #[error("the number of rounds is greater than the number of variables")] + TooManyRounds, + #[error("should provide at least one multi-linear polynomial as input")] + NoMlsProvided, + #[error("failed to generate round challenge")] + FailedToGenerateChallenge, + #[error("the provided multi-linears have different arities")] + MlesDifferentArities, + #[error("multi-linears should have at least one variable")] + AtLeastOneVariable, +} diff --git a/sumcheck/src/prover/high_degree.rs b/sumcheck/src/prover/high_degree.rs new file mode 100644 index 000000000..a96adee4c --- /dev/null +++ b/sumcheck/src/prover/high_degree.rs @@ -0,0 +1,481 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use alloc::vec::Vec; + +use air::LogUpGkrEvaluator; +use crypto::{ElementHasher, RandomCoin}; +use math::FieldElement; +#[cfg(feature = "concurrent")] +pub use rayon::prelude::*; + +use super::SumCheckProverError; +use crate::{ + evaluate_composition_poly, CompressedUnivariatePolyEvals, EqFunction, FinalOpeningClaim, + MultiLinearPoly, RoundProof, SumCheckProof, SumCheckRoundClaim, +}; + +/// A sum-check prover for the input layer which can accommodate non-linear expressions in +/// the numerators of the LogUp relation. +/// +/// The LogUp-GKR protocol in [1] is an IOP for the following statement +/// +/// $$ +/// \sum_{v_i, x_i} \frac{p_n\left(v_1, \cdots, v_{\mu}, x_1, \cdots, x_{\nu}\right)} +/// {q_n\left(v_1, \cdots, v_{\mu}, x_1, \cdots, x_{\nu}\right)} = C +/// $$ +/// +/// where: +/// +/// $$ +/// p_n\left(v_1, \cdots, v_{\mu}, x_1, \cdots, x_{\nu}\right) = +/// \sum_{w\in\{0, 1\}^\mu} EQ\left(\left(v_1, \cdots, v_{\mu}\right), +/// \left(w_1, \cdots, w_{\mu}\right)\right) +/// g_{[w]}\left(f_1\left(x_1, \cdots, x_{\nu}\right), +/// \cdots, f_l\left(x_1, \cdots, x_{\nu}\right)\right) +/// $$ +/// +/// and +/// +/// $$ +/// q_n\left(v_1, \cdots, v_{\mu}, x_1, \cdots, x_{\nu}\right) = +/// \sum_{w\in\{0, 1\}^\mu} EQ\left(\left(v_1, \cdots, v_{\mu}\right), +/// \left(w_1, \cdots, w_{\mu}\right)\right) +/// h_{[w]}\left(f_1\left(x_1, \cdots, x_{\nu}\right), +/// \cdots, f_l\left(x_1, \cdots, x_{\nu}\right)\right) +/// $$ +/// +/// and +/// +/// 1. $f_i$ are multi-linears. +/// 2. ${[w]} := \sum_i w_i \cdot 2^i$ and $w := (w_1, \cdots, w_{\mu})$. +/// 3. $h_{j}$ and $g_{j}$ are multi-variate polynomials for $j = 0, \cdots, 2^{\mu} - 1$. +/// 4. $n := \nu + \mu$ +/// 5. $\mathbb{B}_{\gamma} := \{0, 1\}^{\gamma}$ for positive integer $\gamma$. +/// +/// The sum above is evaluated using a layered circuit with the equation linking the input layer +/// values $p_n$ to the next layer values $p_{n-1}$ given by the following relations +/// +/// $$ +/// p_{n-1}\left(v_2, \cdots, v_{\mu}, x_1, \cdots, x_{\nu}\right) = \sum_{w_i, y_i} +/// EQ\left(\left(v_2, \cdots, v_{\mu}, x_1, \cdots, x_{\nu}\right), +/// \left(w_2, \cdots, w_{\mu}, y_1, \cdots, y_{\nu}\right)\right) +/// \cdot \left( p_n\left(1, w_2, \cdots, w_{\mu}, y_1, \cdots, y_{\nu}\right) +/// \cdot q_n\left(0, w_2, \cdots, w_{\mu}, y_1, \cdots, y_{\nu}\right) + +/// p_n\left(0, w_2, \cdots, w_{\mu}, y_1, \cdots, y_{\nu}\right) \cdot +/// q_n\left(1, w_2, \cdots, w_{\mu}, y_1, \cdots, y_{\nu}\right)\right) +/// $$ +/// +/// and +/// +/// $$ +/// q_{n-1}\left(v_2, \cdots, v_{\mu}, x_1, \cdots, x_{\nu}\right) = \sum_{w_i, y_i} +/// EQ\left(\left(v_2, \cdots, v_{\mu}, x_1, \cdots, x_{\nu}\right), +/// \left(w_2, \cdots, w_{\mu}, y_1, \cdots, y_{\nu}\right)\right) +/// \cdot \left( q_n\left(1, w_2, \cdots, w_{\mu}, y_1, \cdots, y_{\nu}\right) +/// \cdot q_n\left(0, w_2, \cdots, w_{\mu}, y_1, \cdots, y_{\nu}\right)\right) +/// $$ +/// +/// and similarly for all subsequent layers. +/// +/// By the properties of the $EQ$ function, we can write the above as follows: +/// +/// $$ +/// p_{n-1}\left(v_2, \cdots, v_{\mu}, x_1, \cdots, x_{\nu}\right) = \sum_{y_i} +/// EQ\left(\left(x_1, \cdots, x_{\nu}\right), +/// \left(y_1, \cdots, y_{\nu}\right)\right) +/// \left( \sum_{w_i} EQ\left(\left(v_2, \cdots, v_{\mu}\right), +/// \left(w_2, \cdots, w_{\mu}\right)\right) +/// \cdot \left( p_n\left(1, w_2, \cdots, w_{\mu}, y_1, \cdots, y_{\nu}\right) +/// \cdot q_n\left(0, w_2, \cdots, w_{\mu}, y_1, \cdots, y_{\nu}\right) + +/// p_n\left(0, w_2, \cdots, w_{\mu}, y_1, \cdots, y_{\nu}\right) \cdot +/// q_n\left(1, w_2, \cdots, w_{\mu}, y_1, \cdots, y_{\nu}\right)\right) \right) +/// $$ +/// +/// and +/// +/// $$ +/// q_{n-1}\left(v_2, \cdots, v_{\mu}, x_1, \cdots, x_{\nu}\right) = \sum_{y_i} +/// EQ\left(\left(x_1, \cdots, x_{\nu}\right), +/// \left(y_1, \cdots, y_{\nu}\right)\right) +/// \left( \sum_{w_i} EQ\left(\left(v_2, \cdots, v_{\mu}\right)\right) +/// \cdot q_n\left(1, w_2, \cdots, w_{\mu}, y_1, \cdots, y_{\nu}\right) +/// \cdot q_n\left(0, w_2, \cdots, w_{\mu}, y_1, \cdots, y_{\nu}\right) \right) +/// $$ +/// +/// These expressions are nothing but the equations in Section 3.2 in [1] but with the projection +/// happening in the first argument instead of the last one. +/// The current function is then tasked with running a batched sum-check protocol for +/// +/// $$ +/// p_{n-1}\left(v_2, \cdots, v_{\mu}, x_1, \cdots, x_{\nu}\right) = +/// \sum_{y\in\mathbb{B}_{\nu}} G(y_{1}, ..., y_{\nu}) +/// $$ +/// +/// and +/// +/// $$ +/// q_{n-1}\left(v_2, \cdots, v_{\mu}, x_1, \cdots, x_{\nu}\right) = +/// \sum_{y\in\mathbb{B}_{\nu}} H\left(y_1, \cdots, y_{\nu} \right) +/// $$ +/// +/// where +/// +/// $$ +/// G := \left( \left(y_1, \cdots, y_{\nu}\right) \longrightarrow +/// EQ\left(\left(x_1, \cdots, x_{\nu}\right), +/// \left(y_1, \cdots, y_{\nu}\right)\right) +/// \left( \sum_{w_i} EQ\left(\left(v_2, \cdots, v_{\mu}\right), +/// \left(w_2, \cdots, w_{\mu}\right)\right) +/// \cdot \left( p_n\left(1, w_2, \cdots, w_{\mu}, y_1, \cdots, y_{\nu}\right) +/// \cdot q_n\left(0, w_2, \cdots, w_{\mu}, y_1, \cdots, y_{\nu}\right) + +/// p_n\left(0, w_2, \cdots, w_{\mu}, y_1, \cdots, y_{\nu}\right) \cdot +/// q_n\left(1, w_2, \cdots, w_{\mu}, y_1, \cdots, y_{\nu}\right)\right) \right) +/// \right) +/// $$ +/// +/// and +/// +/// $$ +/// H := \left( \left(y_1, \cdots, y_{\nu}\right) \longrightarrow +/// EQ\left(\left(x_1, \cdots, x_{\nu}\right), +/// \left(y_1, \cdots, y_{\nu}\right)\right) +/// \left( \sum_{w_i} EQ\left(\left(v_2, \cdots, v_{\mu}\right)\right) +/// \cdot q_n\left(1, w_2, \cdots, w_{\mu}, y_1, \cdots, y_{\nu}\right) +/// \cdot q_n\left(0, w_2, \cdots, w_{\mu}, y_1, \cdots, y_{\nu}\right) \right) +/// \right) +/// $$ +/// +/// [1]: https://eprint.iacr.org/2023/1284 +#[allow(clippy::too_many_arguments)] +pub fn sum_check_prove_higher_degree< + E: FieldElement, + H: ElementHasher, +>( + evaluator: &impl LogUpGkrEvaluator::BaseField>, + evaluation_point: Vec, + claim: E, + r_sum_check: E, + log_up_randomness: Vec, + mut mls: Vec>, + coin: &mut impl RandomCoin, +) -> Result, SumCheckProverError> { + let num_rounds = mls[0].num_variables(); + + let mut round_proofs = vec![]; + + // split the evaluation point into two points of dimension mu and nu, respectively + let mu = evaluator.get_num_fractions().trailing_zeros() - 1; + let (evaluation_point_mu, evaluation_point_nu) = evaluation_point.split_at(mu as usize); + let eq_mu = EqFunction::ml_at(evaluation_point_mu.into()).evaluations().to_vec(); + let mut eq_nu = EqFunction::ml_at(evaluation_point_nu.into()); + + // setup first round claim + let mut current_round_claim = SumCheckRoundClaim { eval_point: vec![], claim }; + + // run the first round of the protocol + let round_poly_evals = + sumcheck_round(&eq_mu, evaluator, &eq_nu, &mls, &log_up_randomness, r_sum_check); + let round_poly_coefs = round_poly_evals.to_poly(current_round_claim.claim); + + // reseed with the s_0 polynomial + coin.reseed(H::hash_elements(&round_poly_coefs.0)); + round_proofs.push(RoundProof { round_poly_coefs }); + + for i in 1..num_rounds { + // generate random challenge r_i for the i-th round + let round_challenge = + coin.draw().map_err(|_| SumCheckProverError::FailedToGenerateChallenge)?; + + // compute the new reduced round claim + let new_round_claim = + reduce_claim(&round_proofs[i - 1], current_round_claim, round_challenge); + + // fold each multi-linear using the round challenge + mls.iter_mut() + .for_each(|ml| ml.bind_least_significant_variable(round_challenge)); + eq_nu.bind_least_significant_variable(round_challenge); + + // run the i-th round of the protocol using the folded multi-linears for the new reduced + // claim. This basically computes the s_i polynomial. + let round_poly_evals = + sumcheck_round(&eq_mu, evaluator, &eq_nu, &mls, &log_up_randomness, r_sum_check); + + // update the claim + current_round_claim = new_round_claim; + + let round_poly_coefs = round_poly_evals.to_poly(current_round_claim.claim); + + // reseed with the s_i polynomial + coin.reseed(H::hash_elements(&round_poly_coefs.0)); + let round_proof = RoundProof { round_poly_coefs }; + round_proofs.push(round_proof); + } + + // generate the last random challenge + let round_challenge = + coin.draw().map_err(|_| SumCheckProverError::FailedToGenerateChallenge)?; + + // fold each multi-linear using the last random round challenge + mls.iter_mut() + .for_each(|ml| ml.bind_least_significant_variable(round_challenge)); + eq_nu.bind_least_significant_variable(round_challenge); + + let SumCheckRoundClaim { eval_point, claim: _claim } = + reduce_claim(&round_proofs[num_rounds - 1], current_round_claim, round_challenge); + + let openings = mls.iter_mut().map(|ml| ml.evaluations()[0]).collect(); + + Ok(SumCheckProof { + openings_claim: FinalOpeningClaim { eval_point, openings }, + round_proofs, + }) +} + +/// Computes the polynomial +/// +/// $$ +/// s_i(X_i) := \sum_{(x_{i + 1},\cdots, x_{\nu - 1}) +/// w(r_0,\cdots, r_{i - 1}, X_i, x_{i + 1}, \cdots, x_{\nu - 1}). +/// $$ +/// +/// where +/// +/// $$ +/// w(x_0,\cdots, x_{\nu - 1}) := g(f_0((x_0,\cdots, x_{\nu - 1})), +/// \cdots , f_c((x_0,\cdots, x_{\nu - 1}))). +/// $$ +/// +/// where `g` is the expression defined in the documentation of [`sum_check_prove_higher_degree`] +/// +/// Given a degree bound `d_max` for all variables, it suffices to compute the evaluations of `s_i` +/// at `d_max + 1` points. Given that $s_{i}(0) = s_{i}(1) - s_{i - 1}(r_{i - 1})$ it is sufficient +/// to compute the evaluations on only `d_max` points. +/// +/// The algorithm works by iterating over the variables $(x_{i + 1}, \cdots, x_{\nu - 1})$ in +/// ${0, 1}^{\nu - 1 - i}$. For each such tuple, we store the evaluations of the (folded) +/// multi-linears at $(0, x_{i + 1}, \cdots, x_{\nu - 1})$ and +/// $(1, x_{i + 1}, \cdots, x_{\nu - 1})$ in two arrays, `evals_zero` and `evals_one`. +/// Using `evals_one`, remember that we optimize evaluating at 0 away, we get the first evaluation +/// i.e., $s_i(1)$. +/// +/// For the remaining evaluations, we use the fact that the folded `f_i` is multi-linear and hence +/// we can write +/// +/// $$ +/// f_i(X_i, x_{i + 1}, \cdots, x_{\nu - 1}) = +/// (1 - X_i) . f_i(0, x_{i + 1}, \cdots, x_{\nu - 1}) + +/// X_i . f_i(1, x_{i + 1}, \cdots, x_{\nu - 1}) +/// $$ +/// +/// Note that we omitted writing the folding randomness for readability. +/// Since the evaluation domain is $\{0, 1, ... , d_max\}$, we can compute the evaluations based on +/// the previous one using only additions. This is the purpose of `deltas`, to hold the increments +/// added to each multi-linear to compute the evaluation at the next point, and `evals_x` to hold +/// the current evaluation at $x$ in $\{2, ... , d_max\}$. +fn sumcheck_round( + eq_mu: &[E], + evaluator: &impl LogUpGkrEvaluator::BaseField>, + eq_ml: &MultiLinearPoly, + mls: &[MultiLinearPoly], + log_up_randomness: &[E], + r_sum_check: E, +) -> CompressedUnivariatePolyEvals { + let num_ml = mls.len(); + let num_vars = mls[0].num_variables(); + let num_rounds = num_vars - 1; + + #[cfg(not(feature = "concurrent"))] + let evaluations = { + let mut evals_one = vec![E::ZERO; num_ml]; + let mut evals_zero = vec![E::ZERO; num_ml]; + let mut evals_x = vec![E::ZERO; num_ml]; + let mut eq_x = E::ZERO; + + let mut deltas = vec![E::ZERO; num_ml]; + let mut eq_delta = E::ZERO; + + let mut numerators = vec![E::ZERO; evaluator.get_num_fractions()]; + let mut denominators = vec![E::ZERO; evaluator.get_num_fractions()]; + (0..1 << num_rounds) + .map(|i| { + let mut total_evals = vec![E::ZERO; evaluator.max_degree()]; + + for (j, ml) in mls.iter().enumerate() { + evals_zero[j] = ml.evaluations()[2 * i]; + evals_one[j] = ml.evaluations()[2 * i + 1]; + } + + let eq_at_zero = eq_ml.evaluations()[2 * i]; + let eq_at_one = eq_ml.evaluations()[2 * i + 1]; + + // compute the evaluation at 1 + evaluator.evaluate_query( + &evals_one, + log_up_randomness, + &mut numerators, + &mut denominators, + ); + total_evals[0] = evaluate_composition_poly( + eq_mu, + &numerators, + &denominators, + eq_at_one, + r_sum_check, + ); + + // compute the evaluations at 2, ..., d_max points + for i in 0..num_ml { + deltas[i] = evals_one[i] - evals_zero[i]; + evals_x[i] = evals_one[i]; + } + eq_delta = eq_at_one - eq_at_zero; + eq_x = eq_at_one; + + for e in total_evals.iter_mut().skip(1) { + evals_x.iter_mut().zip(deltas.iter()).for_each(|(evx, delta)| { + *evx += *delta; + }); + eq_x += eq_delta; + + evaluator.evaluate_query( + &evals_x, + log_up_randomness, + &mut numerators, + &mut denominators, + ); + *e = evaluate_composition_poly( + eq_mu, + &numerators, + &denominators, + eq_x, + r_sum_check, + ); + } + + total_evals + }) + .fold(vec![E::ZERO; evaluator.max_degree()], |mut acc, poly_eval| { + acc.iter_mut().zip(poly_eval.iter()).for_each(|(a, b)| { + *a += *b; + }); + acc + }) + }; + + #[cfg(feature = "concurrent")] + let evaluations = (0..1 << num_rounds) + .into_par_iter() + .fold( + || { + ( + vec![E::ZERO; num_ml], + vec![E::ZERO; num_ml], + vec![E::ZERO; num_ml], + vec![E::ZERO; evaluator.max_degree()], + vec![E::ZERO; evaluator.get_num_fractions()], + vec![E::ZERO; evaluator.get_num_fractions()], + vec![E::ZERO; num_ml], + ) + }, + |( + mut evals_zero, + mut evals_one, + mut evals_x, + mut poly_evals, + mut numerators, + mut denominators, + mut deltas, + ), + i| { + for (j, ml) in mls.iter().enumerate() { + evals_zero[j] = ml.evaluations()[2 * i]; + evals_one[j] = ml.evaluations()[2 * i + 1]; + } + + let eq_at_zero = eq_ml.evaluations()[2 * i]; + let eq_at_one = eq_ml.evaluations()[2 * i + 1]; + + // compute the evaluation at 1 + evaluator.evaluate_query( + &evals_one, + log_up_randomness, + &mut numerators, + &mut denominators, + ); + poly_evals[0] = evaluate_composition_poly( + eq_mu, + &numerators, + &denominators, + eq_at_one, + r_sum_check, + ); + + // compute the evaluations at 2, ..., d_max points + for i in 0..num_ml { + deltas[i] = evals_one[i] - evals_zero[i]; + evals_x[i] = evals_one[i]; + } + let eq_delta = eq_at_one - eq_at_zero; + let mut eq_x = eq_at_one; + + for e in poly_evals.iter_mut().skip(1) { + evals_x.iter_mut().zip(deltas.iter()).for_each(|(evx, delta)| { + *evx += *delta; + }); + eq_x += eq_delta; + + evaluator.evaluate_query( + &evals_x, + log_up_randomness, + &mut numerators, + &mut denominators, + ); + *e = evaluate_composition_poly( + eq_mu, + &numerators, + &denominators, + eq_x, + r_sum_check, + ); + } + + (evals_zero, evals_one, evals_x, poly_evals, numerators, denominators, deltas) + }, + ) + .map(|(_, _, _, poly_evals, ..)| poly_evals) + .reduce( + || vec![E::ZERO; evaluator.max_degree()], + |mut acc, poly_eval| { + acc.iter_mut().zip(poly_eval.iter()).for_each(|(a, b)| { + *a += *b; + }); + acc + }, + ); + + CompressedUnivariatePolyEvals(evaluations.into()) +} + +/// Reduces an old claim to a new claim using the round challenge. +fn reduce_claim( + current_poly: &RoundProof, + current_round_claim: SumCheckRoundClaim, + round_challenge: E, +) -> SumCheckRoundClaim { + // evaluate the round polynomial at the round challenge to obtain the new claim + let new_claim = current_poly + .round_poly_coefs + .evaluate_using_claim(¤t_round_claim.claim, &round_challenge); + + // update the evaluation point using the round challenge + let mut new_partial_eval_point = current_round_claim.eval_point; + new_partial_eval_point.push(round_challenge); + + SumCheckRoundClaim { + eval_point: new_partial_eval_point, + claim: new_claim, + } +} diff --git a/sumcheck/src/prover/mod.rs b/sumcheck/src/prover/mod.rs new file mode 100644 index 000000000..13d35e551 --- /dev/null +++ b/sumcheck/src/prover/mod.rs @@ -0,0 +1,13 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +mod high_degree; +pub use high_degree::sum_check_prove_higher_degree; + +mod plain; +pub use plain::sumcheck_prove_plain; + +mod error; +pub use error::SumCheckProverError; diff --git a/sumcheck/src/prover/plain.rs b/sumcheck/src/prover/plain.rs new file mode 100644 index 000000000..e0092cf10 --- /dev/null +++ b/sumcheck/src/prover/plain.rs @@ -0,0 +1,216 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use crypto::{ElementHasher, RandomCoin}; +use math::FieldElement; +#[cfg(feature = "concurrent")] +pub use rayon::prelude::*; +use smallvec::smallvec; + +use super::SumCheckProverError; +use crate::{ + comb_func, CompressedUnivariatePolyEvals, FinalOpeningClaim, MultiLinearPoly, RoundProof, + SumCheckProof, +}; + +/// Sum-check prover for non-linear multivariate polynomial of the simple LogUp-GKR. +/// +/// More specifically, the following function implements the logic of the sum-check prover as +/// described in Section 3.2 in [1], that is, given verifier challenges , the following implements +/// the sum-check prover for the following two statements +/// $$ +/// p_{\nu - \kappa}\left(v_{\kappa+1}, \cdots, v_{\nu}\right) = \sum_{w_i} +/// EQ\left(\left(v_{\kappa+1}, \cdots, v_{\nu}\right), \left(w_{\kappa+1}, \cdots, +/// w_{\nu}\right)\right) \cdot +/// \left( p_{\nu-\kappa+1}\left(1, w_{\kappa+1}, \cdots, w_{\nu}\right) \cdot +/// q_{\nu-\kappa+1}\left(0, w_{\kappa+1}, \cdots, w_{\nu}\right) + +/// p_{\nu-\kappa+1}\left(0, w_{\kappa+1}, \cdots, w_{\nu}\right) \cdot +/// q_{\nu-\kappa+1}\left(1, w_{\kappa+1}, \cdots, w_{\nu}\right)\right) +/// $$ +/// +/// and +/// +/// $$ +/// q_{\nu -k}\left(v_{\kappa+1}, \cdots, v_{\nu}\right) = \sum_{w_i}EQ\left(\left(v_{\kappa+1}, +/// \cdots, v_{\nu}\right), \left(w_{\kappa+1}, \cdots, w_{\nu }\right)\right) \cdot +/// \left( q_{\nu-\kappa+1}\left(1, w_{\kappa+1}, \cdots, w_{\nu}\right) \cdot +/// q_{\nu-\kappa+1}\left(0, w_{\kappa+1}, \cdots, w_{\nu}\right)\right) +/// $$ +/// +/// for $k = 1, \cdots, \nu - 1$ +/// +/// Instead of executing two runs of the sum-check protocol, a batching randomness `r_batch` is +/// sent by the verifier at the outset in order to batch the two statments. +/// +/// Note that the degree of the non-linear composition polynomial is 3. +/// +/// [1]: https://eprint.iacr.org/2023/1284 +#[allow(clippy::too_many_arguments)] +pub fn sumcheck_prove_plain>( + mut claim: E, + r_batch: E, + p: MultiLinearPoly, + q: MultiLinearPoly, + eq: &mut MultiLinearPoly, + transcript: &mut impl RandomCoin, +) -> Result, SumCheckProverError> { + let mut round_proofs = vec![]; + + let mut challenges = vec![]; + + // construct the vector of multi-linear polynomials + let (mut p0, mut p1) = p.project_least_significant_variable(); + let (mut q0, mut q1) = q.project_least_significant_variable(); + + for _ in 0..p0.num_variables() { + let len = p0.num_evaluations() / 2; + + #[cfg(not(feature = "concurrent"))] + let (round_poly_eval_at_1, round_poly_eval_at_2, round_poly_eval_at_3) = (0..len).fold( + (E::ZERO, E::ZERO, E::ZERO), + |(acc_point_1, acc_point_2, acc_point_3), i| { + let round_poly_eval_at_1 = comb_func( + p0[2 * i + 1], + p1[2 * i + 1], + q0[2 * i + 1], + q1[2 * i + 1], + eq[2 * i + 1], + r_batch, + ); + + let p0_delta = p0[2 * i + 1] - p0[2 * i]; + let p1_delta = p1[2 * i + 1] - p1[2 * i]; + let q0_delta = q0[2 * i + 1] - q0[2 * i]; + let q1_delta = q1[2 * i + 1] - q1[2 * i]; + let eq_delta = eq[2 * i + 1] - eq[2 * i]; + + let mut p0_eval_at_x = p0[2 * i + 1] + p0_delta; + let mut p1_eval_at_x = p1[2 * i + 1] + p1_delta; + let mut q0_eval_at_x = q0[2 * i + 1] + q0_delta; + let mut q1_eval_at_x = q1[2 * i + 1] + q1_delta; + let mut eq_evx = eq[2 * i + 1] + eq_delta; + let round_poly_eval_at_2 = comb_func( + p0_eval_at_x, + p1_eval_at_x, + q0_eval_at_x, + q1_eval_at_x, + eq_evx, + r_batch, + ); + + p0_eval_at_x += p0_delta; + p1_eval_at_x += p1_delta; + q0_eval_at_x += q0_delta; + q1_eval_at_x += q1_delta; + eq_evx += eq_delta; + let round_poly_eval_at_3 = comb_func( + p0_eval_at_x, + p1_eval_at_x, + q0_eval_at_x, + q1_eval_at_x, + eq_evx, + r_batch, + ); + + ( + round_poly_eval_at_1 + acc_point_1, + round_poly_eval_at_2 + acc_point_2, + round_poly_eval_at_3 + acc_point_3, + ) + }, + ); + + #[cfg(feature = "concurrent")] + let (round_poly_eval_at_1, round_poly_eval_at_2, round_poly_eval_at_3) = (0..len) + .into_par_iter() + .fold( + || (E::ZERO, E::ZERO, E::ZERO), + |(a, b, c), i| { + let round_poly_eval_at_1 = comb_func( + p0[2 * i + 1], + p1[2 * i + 1], + q0[2 * i + 1], + q1[2 * i + 1], + eq[2 * i + 1], + r_batch, + ); + + let p0_delta = p0[2 * i + 1] - p0[2 * i]; + let p1_delta = p1[2 * i + 1] - p1[2 * i]; + let q0_delta = q0[2 * i + 1] - q0[2 * i]; + let q1_delta = q1[2 * i + 1] - q1[2 * i]; + let eq_delta = eq[2 * i + 1] - eq[2 * i]; + + let mut p0_eval_at_x = p0[2 * i + 1] + p0_delta; + let mut p1_eval_at_x = p1[2 * i + 1] + p1_delta; + let mut q0_eval_at_x = q0[2 * i + 1] + q0_delta; + let mut q1_eval_at_x = q1[2 * i + 1] + q1_delta; + let mut eq_evx = eq[2 * i + 1] + eq_delta; + let round_poly_eval_at_2 = comb_func( + p0_eval_at_x, + p1_eval_at_x, + q0_eval_at_x, + q1_eval_at_x, + eq_evx, + r_batch, + ); + + p0_eval_at_x += p0_delta; + p1_eval_at_x += p1_delta; + q0_eval_at_x += q0_delta; + q1_eval_at_x += q1_delta; + eq_evx += eq_delta; + let round_poly_eval_at_3 = comb_func( + p0_eval_at_x, + p1_eval_at_x, + q0_eval_at_x, + q1_eval_at_x, + eq_evx, + r_batch, + ); + + (round_poly_eval_at_1 + a, round_poly_eval_at_2 + b, round_poly_eval_at_3 + c) + }, + ) + .reduce( + || (E::ZERO, E::ZERO, E::ZERO), + |(a0, b0, c0), (a1, b1, c1)| (a0 + a1, b0 + b1, c0 + c1), + ); + + let evals = smallvec![round_poly_eval_at_1, round_poly_eval_at_2, round_poly_eval_at_3]; + let compressed_round_poly_evals = CompressedUnivariatePolyEvals(evals); + let compressed_round_poly = compressed_round_poly_evals.to_poly(claim); + + // reseed with the s_i polynomial + transcript.reseed(H::hash_elements(&compressed_round_poly.0)); + let round_proof = RoundProof { + round_poly_coefs: compressed_round_poly.clone(), + }; + + let round_challenge = + transcript.draw().map_err(|_| SumCheckProverError::FailedToGenerateChallenge)?; + + // fold each multi-linear using the round challenge + p0.bind_least_significant_variable(round_challenge); + p1.bind_least_significant_variable(round_challenge); + q0.bind_least_significant_variable(round_challenge); + q1.bind_least_significant_variable(round_challenge); + eq.bind_least_significant_variable(round_challenge); + + // compute the new reduced round claim + claim = compressed_round_poly.evaluate_using_claim(&claim, &round_challenge); + + round_proofs.push(round_proof); + challenges.push(round_challenge); + } + + Ok(SumCheckProof { + openings_claim: FinalOpeningClaim { + eval_point: challenges, + openings: vec![p0[0], p1[0], q0[0], q1[0]], + }, + round_proofs, + }) +} diff --git a/sumcheck/src/univariate.rs b/sumcheck/src/univariate.rs new file mode 100644 index 000000000..082a4daf9 --- /dev/null +++ b/sumcheck/src/univariate.rs @@ -0,0 +1,295 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use alloc::vec::Vec; + +use math::{batch_inversion, polynom, FieldElement}; +use smallvec::SmallVec; +use utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; + +// CONSTANTS +// ================================================================================================ + +/// Maximum expected size of the round polynomials. This is needed for `SmallVec`. The size of +/// the round polynomials is dictated by the degree of the non-linearity in the sum-check statement +/// which is direcly influenced by the maximal degrees of the numerators and denominators appearing +/// in the LogUp-GKR relation and equal to one plus the maximal degree of the numerators and +/// maximal degree of denominators. +/// The following value assumes that this degree is at most 10. +const MAX_POLY_SIZE: usize = 10; + +// COMPRESSED UNIVARIATE POLYNOMIAL +// ================================================================================================ + +/// The coefficients of a univariate polynomial of degree n with the linear term coefficient +/// omitted. +/// +/// This compressed representation is useful during the sum-check protocol as the full uncompressed +/// representation can be recovered from the compressed one and the current sum-check round claim. +#[derive(Clone, Debug, PartialEq)] +pub struct CompressedUnivariatePoly(pub(crate) SmallVec<[E; MAX_POLY_SIZE]>); + +impl CompressedUnivariatePoly { + /// Evaluates a polynomial at a challenge point using a round claim. + /// + /// The round claim is used to recover the coefficient of the linear term using the relation + /// 2 * c0 + c1 + ... c_{n - 1} = claim. Using the complete list of coefficients, the polynomial + /// is then evaluated using Horner's method. + pub fn evaluate_using_claim(&self, claim: &E, challenge: &E) -> E { + // recover the coefficient of the linear term + let c1 = *claim - self.0.iter().fold(E::ZERO, |acc, term| acc + *term) - self.0[0]; + + // construct the full coefficient list + let mut complete_coefficients = vec![self.0[0], c1]; + complete_coefficients.extend_from_slice(&self.0[1..]); + + // evaluate + polynom::eval(&complete_coefficients, *challenge) + } +} + +impl Serializable for CompressedUnivariatePoly { + fn write_into(&self, target: &mut W) { + let vector: Vec = self.0.clone().into_vec(); + vector.write_into(target); + } +} + +impl Deserializable for CompressedUnivariatePoly +where + E: FieldElement, +{ + fn read_from(source: &mut R) -> Result { + let vector: Vec = Vec::::read_from(source)?; + Ok(Self(vector.into())) + } +} + +/// The evaluations of a univariate polynomial of degree n at 0, 1, ..., n with the evaluation at 0 +/// omitted. +/// +/// This compressed representation is useful during the sum-check protocol as the full uncompressed +/// representation can be recovered from the compressed one and the current sum-check round claim. +#[derive(Clone, Debug)] +pub struct CompressedUnivariatePolyEvals(pub(crate) SmallVec<[E; MAX_POLY_SIZE]>); + +impl CompressedUnivariatePolyEvals { + /// Gives the coefficient representation of a polynomial represented in evaluation form. + /// + /// Since the evaluation at 0 is omitted, we need to use the round claim to recover + /// the evaluation at 0 using the identity $p(0) + p(1) = claim$. + /// Now, we have that for any polynomial $p(x) = c0 + c1 * x + ... + c_{n-1} * x^{n - 1}$: + /// + /// 1. $p(0) = c0$. + /// 2. $p(x) = c0 + x * q(x) where q(x) = c1 + ... + c_{n-1} * x^{n - 2}$. + /// + /// This means that we can compute the evaluations of q at 1, ..., n - 1 using the evaluations + /// of p and thus reduce by 1 the size of the interpolation problem. + /// Once the coefficient of q are recovered, the c0 coefficient is appended to these and this + /// is precisely the coefficient representation of the original polynomial q. + /// Note that the coefficient of the linear term is removed as this coefficient can be recovered + /// from the remaining coefficients, again, using the round claim using the relation + /// $2 * c0 + c1 + ... c_{n - 1} = claim$. + pub fn to_poly(&self, round_claim: E) -> CompressedUnivariatePoly { + // construct the vector of interpolation points 1, ..., n + let n_minus_1 = self.0.len(); + let points = (1..=n_minus_1 as u32).map(E::BaseField::from).collect::>(); + + // construct their inverses. These will be needed for computing the evaluations + // of the q polynomial as well as for doing the interpolation on q + let points_inv = batch_inversion(&points); + + // compute the zeroth coefficient + let c0 = round_claim - self.0[0]; + + // compute the evaluations of q + let q_evals: Vec = self + .0 + .iter() + .enumerate() + .map(|(i, evals)| (*evals - c0).mul_base(points_inv[i])) + .collect(); + + // interpolate q + let q_coefs = multiply_by_inverse_vandermonde(&q_evals, &points_inv); + + // append c0 to the coefficients of q to get the coefficients of p. The linear term + // coefficient is removed as this can be recovered from the other coefficients using + // the reduced claim. + let mut coefficients = SmallVec::with_capacity(self.0.len() + 1); + coefficients.push(c0); + coefficients.extend_from_slice(&q_coefs[1..]); + + CompressedUnivariatePoly(coefficients) + } +} + +// HELPER FUNCTIONS +// ================================================================================================ + +/// Given a (row) vector `v`, computes the vector-matrix product `v * V^{-1}` where `V` is +/// the Vandermonde matrix over the points `1, ..., n` where `n` is the length of `v`. +/// The resulting vector will then be the coefficients of the minimal interpolating polynomial +/// through the points `(i+1, v[i])` for `i` in `0, ..., n - 1` +/// +/// The naive way would be to invert the matrix `V` and then compute the vector-matrix product +/// this will cost `O(n^3)` operations and `O(n^2)` memory. We can also try Gaussian elimination +/// but this is also worst case `O(n^3)` operations and `O(n^2)` memory. +/// In the following implementation, we use the fact that the points over which we are interpolating +/// is a set of equidistant points and thus both the Vandermonde matrix and its inverse can be +/// described by sparse linear recurrence equations. +/// More specifically, we use the representation given in [1], where `V^{-1}` is represented as +/// `U * M` where: +/// +/// 1. `M` is a lower triangular matrix where its entries are given by M(i, j) = M(i - 1, j) - M(i - +/// 1, j - 1) / (i - 1) with boundary conditions M(i, 1) = 1 and M(i, j) = 0 when j > i. +/// +/// 2. `U` is an upper triangular (involutory) matrix where its entries are given by U(i, j) = U(i, +/// j - 1) - U(i - 1, j - 1) with boundary condition U(1, j) = 1 and U(i, j) = 0 when i > j. +/// +/// Note that the matrix indexing in the formulas above matches the one in the reference and starts +/// from 1. +/// +/// The above implies that we can do the vector-matrix multiplication in `O(n^2)` and using only +/// `O(n)` space. +/// +/// [1]: https://link.springer.com/article/10.1007/s002110050360 +fn multiply_by_inverse_vandermonde( + vector: &[E], + nodes_inv: &[E::BaseField], +) -> Vec { + let res = multiply_by_u(vector); + multiply_by_m(&res, nodes_inv) +} + +/// Multiplies a (row) vector `v` by an upper triangular matrix `U` to compute `v * U`. +/// +/// `U` is an upper triangular (involutory) matrix with its entries given by +/// U(i, j) = U(i, j - 1) - U(i - 1, j - 1) +/// with boundary condition U(1, j) = 1 and U(i, j) = 0 when i > j. +fn multiply_by_u(vector: &[E]) -> Vec { + let n = vector.len(); + let mut previous_u_col = vec![E::BaseField::ZERO; n]; + previous_u_col[0] = E::BaseField::ONE; + let mut current_u_col = vec![E::BaseField::ZERO; n]; + current_u_col[0] = E::BaseField::ONE; + + let mut result: Vec = vec![E::ZERO; n]; + for (i, res) in result.iter_mut().enumerate() { + *res = vector[0]; + + for (j, v) in vector.iter().enumerate().take(i + 1).skip(1) { + let u_entry: E::BaseField = + compute_u_entry::(j, &mut previous_u_col, &mut current_u_col); + *res += v.mul_base(u_entry); + } + previous_u_col.clone_from(¤t_u_col); + } + + result +} + +/// Multiplies a (row) vector `v` by a lower triangular matrix `M` to compute `v * M`. +/// +/// `M` is a lower triangular matrix with its entries given by +/// M(i, j) = M(i - 1, j) - M(i - 1, j - 1) / (i - 1) +/// with boundary conditions M(i, 1) = 1 and M(i, j) = 0 when j > i. +fn multiply_by_m(vector: &[E], nodes_inv: &[E::BaseField]) -> Vec { + let n = vector.len(); + let mut previous_m_col = vec![E::BaseField::ONE; n]; + let mut current_m_col = vec![E::BaseField::ZERO; n]; + current_m_col[0] = E::BaseField::ONE; + + let mut result: Vec = vec![E::ZERO; n]; + result[0] = vector.iter().fold(E::ZERO, |acc, term| acc + *term); + for (i, res) in result.iter_mut().enumerate().skip(1) { + current_m_col = vec![E::BaseField::ZERO; n]; + + for (j, v) in vector.iter().enumerate().skip(i) { + let m_entry: E::BaseField = + compute_m_entry::(j, &mut previous_m_col, &mut current_m_col, nodes_inv[j - 1]); + *res += v.mul_base(m_entry); + } + previous_m_col.clone_from(¤t_m_col); + } + + result +} + +/// Returns the j-th entry of the i-th column of matrix `U` given the values of the (i - 1)-th +/// column. The i-th column is also updated with the just computed `U(i, j)` entry. +/// +/// `U` is an upper triangular (involutory) matrix with its entries given by +/// U(i, j) = U(i, j - 1) - U(i - 1, j - 1) +/// with boundary condition U(1, j) = 1 and U(i, j) = 0 when i > j. +fn compute_u_entry( + j: usize, + col_prev: &mut [E::BaseField], + col_cur: &mut [E::BaseField], +) -> E::BaseField { + let value = col_prev[j] - col_prev[j - 1]; + col_cur[j] = value; + value +} + +/// Returns the j-th entry of the i-th column of matrix `M` given the values of the (i - 1)-th +/// and the i-th columns. The i-th column is also updated with the just computed `M(i, j)` entry. +/// +/// `M` is a lower triangular matrix with its entries given by +/// M(i, j) = M(i - 1, j) - M(i - 1, j - 1) / (i - 1) +/// with boundary conditions M(i, 1) = 1 and M(i, j) = 0 when j > i. +fn compute_m_entry( + j: usize, + col_previous: &mut [E::BaseField], + col_current: &mut [E::BaseField], + node_inv: E::BaseField, +) -> E::BaseField { + let value = col_current[j - 1] - node_inv * col_previous[j - 1]; + col_current[j] = value; + value +} + +// TESTS +// ================================================================================================ + +#[test] +fn test_poly_partial() { + use math::fields::f64::BaseElement; + + let degree = 1000; + let mut points: Vec = vec![BaseElement::ZERO; degree]; + points + .iter_mut() + .enumerate() + .for_each(|(i, node)| *node = BaseElement::from(i as u32)); + + let p: Vec = rand_utils::rand_vector(degree); + let evals = polynom::eval_many(&p, &points); + + let mut partial_evals = evals.clone(); + partial_evals.remove(0); + + let partial_poly = CompressedUnivariatePolyEvals(partial_evals.into()); + let claim = evals[0] + evals[1]; + let poly_coeff = partial_poly.to_poly(claim); + + let r = rand_utils::rand_vector(1); + + assert_eq!(polynom::eval(&p, r[0]), poly_coeff.evaluate_using_claim(&claim, &r[0])) +} + +#[test] +fn test_serialization() { + use math::fields::f64::BaseElement; + + let original_poly = + CompressedUnivariatePoly(rand_utils::rand_array::().into()); + let poly_bytes = original_poly.to_bytes(); + + let deserialized_poly = + CompressedUnivariatePoly::::read_from_bytes(&poly_bytes).unwrap(); + + assert_eq!(original_poly, deserialized_poly) +} diff --git a/sumcheck/src/verifier/mod.rs b/sumcheck/src/verifier/mod.rs new file mode 100644 index 000000000..d1cfae3a4 --- /dev/null +++ b/sumcheck/src/verifier/mod.rs @@ -0,0 +1,149 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use alloc::vec::Vec; + +use air::LogUpGkrEvaluator; +use crypto::{ElementHasher, RandomCoin}; +use math::FieldElement; + +use crate::{ + comb_func, evaluate_composition_poly, EqFunction, FinalLayerProof, FinalOpeningClaim, + RoundProof, SumCheckProof, SumCheckRoundClaim, +}; + +/// Verifies sum-check proofs, as part of the GKR proof, for all GKR layers except for the last one +/// i.e., the circuit input layer. +pub fn verify_sum_check_intermediate_layers< + E: FieldElement, + H: ElementHasher, +>( + proof: &SumCheckProof, + gkr_eval_point: &[E], + claim: (E, E), + transcript: &mut impl RandomCoin, +) -> Result, SumCheckVerifierError> { + // generate challenge to batch sum-checks + transcript.reseed(H::hash_elements(&[claim.0, claim.1])); + let r_batch: E = transcript + .draw() + .map_err(|_| SumCheckVerifierError::FailedToGenerateChallenge)?; + + // compute the claim for the batched sum-check + let reduced_claim = claim.0 + claim.1 * r_batch; + + let SumCheckProof { openings_claim, round_proofs } = proof; + + let final_round_claim = verify_rounds(reduced_claim, round_proofs, transcript)?; + assert_eq!(openings_claim.eval_point, final_round_claim.eval_point); + + let p0 = openings_claim.openings[0]; + let p1 = openings_claim.openings[1]; + let q0 = openings_claim.openings[2]; + let q1 = openings_claim.openings[3]; + + let eq = EqFunction::new(gkr_eval_point.into()).evaluate(&openings_claim.eval_point); + + if comb_func(p0, p1, q0, q1, eq, r_batch) != final_round_claim.claim { + return Err(SumCheckVerifierError::FinalEvaluationCheckFailed); + } + + Ok(openings_claim.clone()) +} + +/// Verifies the final sum-check proof i.e., the one for the input layer, including the final check, +/// and returns a [`FinalOpeningClaim`] to the STARK verifier in order to verify the correctness of +/// the openings. +pub fn verify_sum_check_input_layer>( + evaluator: &impl LogUpGkrEvaluator, + proof: &FinalLayerProof, + log_up_randomness: Vec, + gkr_eval_point: &[E], + claim: (E, E), + transcript: &mut impl RandomCoin, +) -> Result, SumCheckVerifierError> { + let FinalLayerProof { proof } = proof; + + // generate challenge to batch sum-checks + transcript.reseed(H::hash_elements(&[claim.0, claim.1])); + let r_batch: E = transcript + .draw() + .map_err(|_| SumCheckVerifierError::FailedToGenerateChallenge)?; + + // compute the claim for the batched sum-check + let reduced_claim = claim.0 + claim.1 * r_batch; + + // verify the sum-check proof + let SumCheckRoundClaim { eval_point, claim } = + verify_rounds(reduced_claim, &proof.round_proofs, transcript)?; + + // execute the final evaluation check + if proof.openings_claim.eval_point != eval_point { + return Err(SumCheckVerifierError::WrongOpeningPoint); + } + + let mut numerators = vec![E::ZERO; evaluator.get_num_fractions()]; + let mut denominators = vec![E::ZERO; evaluator.get_num_fractions()]; + evaluator.evaluate_query( + &proof.openings_claim.openings, + &log_up_randomness, + &mut numerators, + &mut denominators, + ); + + let mu = evaluator.get_num_fractions().trailing_zeros() - 1; + let (evaluation_point_mu, evaluation_point_nu) = gkr_eval_point.split_at(mu as usize); + + let eq_mu = EqFunction::new(evaluation_point_mu.into()).evaluations(); + let eq_nu = EqFunction::new(evaluation_point_nu.into()); + + let eq_nu_eval = eq_nu.evaluate(&proof.openings_claim.eval_point); + let expected_evaluation = + evaluate_composition_poly(&eq_mu, &numerators, &denominators, eq_nu_eval, r_batch); + + if expected_evaluation != claim { + Err(SumCheckVerifierError::FinalEvaluationCheckFailed) + } else { + Ok(proof.openings_claim.clone()) + } +} + +/// Verifies a round of the sum-check protocol without executing the final check. +fn verify_rounds( + claim: E, + round_proofs: &[RoundProof], + coin: &mut impl RandomCoin, +) -> Result, SumCheckVerifierError> +where + E: FieldElement, + H: ElementHasher, +{ + let mut round_claim = claim; + let mut evaluation_point = vec![]; + for round_proof in round_proofs { + let round_poly_coefs = round_proof.round_poly_coefs.clone(); + coin.reseed(H::hash_elements(&round_poly_coefs.0)); + + let r = coin.draw().map_err(|_| SumCheckVerifierError::FailedToGenerateChallenge)?; + + round_claim = round_proof.round_poly_coefs.evaluate_using_claim(&round_claim, &r); + evaluation_point.push(r); + } + + Ok(SumCheckRoundClaim { + eval_point: evaluation_point, + claim: round_claim, + }) +} + +#[derive(Debug, thiserror::Error)] +pub enum SumCheckVerifierError { + #[error("the final evaluation check of sum-check failed")] + FinalEvaluationCheckFailed, + #[error("failed to generate round challenge")] + FailedToGenerateChallenge, + #[error("wrong opening point for the oracles")] + WrongOpeningPoint, +} From aef404d4bcd11dce2cbe26e71450b2eead032411 Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Fri, 30 Aug 2024 16:06:38 +0200 Subject: [PATCH 02/19] Implement GKR backend for LogUp-GKR (#296) --- .cargo/katex-header.html | 1 + air/src/air/aux.rs | 161 +++---- air/src/air/boundary/mod.rs | 10 +- air/src/air/coefficients.rs | 24 +- air/src/air/context.rs | 70 ++- air/src/air/logup_gkr.rs | 117 ++++- air/src/air/mod.rs | 56 ++- air/src/air/tests.rs | 22 +- air/src/air/trace_info.rs | 29 +- air/src/air/transition/mod.rs | 2 +- air/src/lib.rs | 2 +- air/src/proof/context.rs | 11 +- air/src/proof/ood_frame.rs | 2 + crypto/src/merkle/concurrent.rs | 7 +- examples/src/fibonacci/fib2/air.rs | 8 +- examples/src/fibonacci/fib8/air.rs | 8 +- examples/src/fibonacci/fib_small/air.rs | 8 +- examples/src/fibonacci/mulfib2/air.rs | 8 +- examples/src/fibonacci/mulfib8/air.rs | 8 +- examples/src/lamport/aggregate/air.rs | 8 +- examples/src/lamport/threshold/air.rs | 10 +- examples/src/merkle/air.rs | 9 +- examples/src/rescue/air.rs | 9 +- examples/src/rescue_raps/air.rs | 9 +- .../src/rescue_raps/custom_trace_table.rs | 6 +- examples/src/rescue_raps/prover.rs | 29 +- examples/src/utils/rescue.rs | 2 + examples/src/vdf/exempt/air.rs | 10 +- examples/src/vdf/regular/air.rs | 8 +- math/src/field/f64/mod.rs | 7 +- prover/Cargo.toml | 2 + prover/benches/lagrange_kernel.rs | 52 +-- prover/src/constraints/evaluator/default.rs | 6 +- .../constraints/evaluator/periodic_table.rs | 6 +- prover/src/errors.rs | 5 + prover/src/lib.rs | 105 +++-- prover/src/logup_gkr/mod.rs | 403 ++++++++++++++++++ prover/src/logup_gkr/prover.rs | 256 +++++++++++ prover/src/tests/mod.rs | 12 +- prover/src/trace/mod.rs | 20 +- prover/src/trace/tests.rs | 2 +- prover/src/trace/trace_lde/default/tests.rs | 2 +- prover/src/trace/trace_table.rs | 4 +- sumcheck/benches/sum_check_high_degree.rs | 26 +- sumcheck/src/lib.rs | 19 +- sumcheck/src/prover/high_degree.rs | 6 +- sumcheck/src/verifier/mod.rs | 14 +- verifier/Cargo.toml | 2 + verifier/src/channel.rs | 17 +- verifier/src/evaluator.rs | 23 +- verifier/src/lib.rs | 57 +-- verifier/src/logup_gkr/mod.rs | 115 +++++ winterfell/src/lib.rs | 37 +- winterfell/src/tests.rs | 281 ++++++------ 54 files changed, 1556 insertions(+), 577 deletions(-) create mode 100644 prover/src/logup_gkr/mod.rs create mode 100644 prover/src/logup_gkr/prover.rs create mode 100644 verifier/src/logup_gkr/mod.rs diff --git a/.cargo/katex-header.html b/.cargo/katex-header.html index 5db5bc0b1..ca338654e 100644 --- a/.cargo/katex-header.html +++ b/.cargo/katex-header.html @@ -11,6 +11,7 @@ renderMathInElement(document.body, { fleqn: false, macros: { + "\\B": "\\mathbb{B}", "\\F": "\\mathbb{F}", "\\G": "\\mathbb{G}", "\\O": "\\mathcal{O}", diff --git a/air/src/air/aux.rs b/air/src/air/aux.rs index 01f59035a..7dc9c4f48 100644 --- a/air/src/air/aux.rs +++ b/air/src/air/aux.rs @@ -3,39 +3,32 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. -use alloc::{string::ToString, vec::Vec}; +use alloc::vec::Vec; -use crypto::{ElementHasher, RandomCoin, RandomCoinError}; use math::FieldElement; -use utils::Deserializable; -use super::lagrange::LagrangeKernelRandElements; +use super::{lagrange::LagrangeKernelRandElements, LogUpGkrOracle}; -/// Holds the randomly generated elements necessary to build the auxiliary trace. +/// Holds the randomly generated elements used in defining the auxiliary segment of the trace. /// -/// Specifically, [`AuxRandElements`] currently supports 3 types of random elements: -/// - the ones needed to build the Lagrange kernel column (when using GKR to accelerate LogUp), -/// - the ones needed to build the "s" auxiliary column (when using GKR to accelerate LogUp), -/// - the ones needed to build all the other auxiliary columns +/// Specifically, [`AuxRandElements`] currently supports 2 types of random elements: +/// - the ones needed to build all the auxiliary columns except for the ones associated +/// to LogUp-GKR. +/// - the ones needed to build the "s" and Lagrange kernel auxiliary columns (when using GKR to +/// accelerate LogUp). These also include additional information needed to evaluate constraints +/// one these two columns. #[derive(Debug, Clone)] -pub struct AuxRandElements { +pub struct AuxRandElements { rand_elements: Vec, - gkr: Option>, + gkr: Option>, } -impl AuxRandElements { - /// Creates a new [`AuxRandElements`], where the auxiliary trace doesn't contain a Lagrange - /// kernel column. - pub fn new(rand_elements: Vec) -> Self { - Self { rand_elements, gkr: None } - } - - /// Creates a new [`AuxRandElements`], where the auxiliary trace contains columns needed when +impl AuxRandElements { + /// Creates a new [`AuxRandElements`], where the auxiliary segment may contain columns needed when /// using GKR to accelerate LogUp (i.e. a Lagrange kernel column and the "s" column). - pub fn new_with_gkr(rand_elements: Vec, gkr: GkrRandElements) -> Self { - Self { rand_elements, gkr: Some(gkr) } + pub fn new(rand_elements: Vec, gkr: Option>) -> Self { + Self { rand_elements, gkr } } - /// Returns the random elements needed to build all columns other than the two GKR-related ones. pub fn rand_elements(&self) -> &[E] { &self.rand_elements @@ -43,7 +36,7 @@ impl AuxRandElements { /// Returns the random elements needed to build the Lagrange kernel column. pub fn lagrange(&self) -> Option<&LagrangeKernelRandElements> { - self.gkr.as_ref().map(|gkr| &gkr.lagrange) + self.gkr.as_ref().map(|gkr| &gkr.lagrange_kernel_eval_point) } /// Returns the random values used to linearly combine the openings returned from the GKR proof. @@ -52,83 +45,97 @@ impl AuxRandElements { pub fn gkr_openings_combining_randomness(&self) -> Option<&[E]> { self.gkr.as_ref().map(|gkr| gkr.openings_combining_randomness.as_ref()) } + + /// Returns a collection of data necessary for implementing the univariate IOP for multi-linear + /// evaluations of [1] when LogUp-GKR is enabled, else returns a `None`. + /// + /// [1]: https://eprint.iacr.org/2023/1284 + pub fn gkr_data(&self) -> Option> { + self.gkr.clone() + } } -/// Holds all the random elements needed when using GKR to accelerate LogUp. +/// Holds all the data needed when using LogUp-GKR in order to build and verify the correctness of +/// two extra auxiliary columns required for running the univariate IOP for multi-linear +/// evaluations of [1]. /// -/// This consists of two sets of random values: -/// 1. The Lagrange kernel random elements (expanded on in [`LagrangeKernelRandElements`]), and +/// This consists of: +/// 1. The Lagrange kernel random elements (expanded on in [`LagrangeKernelRandElements`]). These +/// make up the evaluation point of the multi-linear extension polynomials underlying the oracles +/// in point 4 below. /// 2. The "openings combining randomness". +/// 3. The openings of the multi-linear extension polynomials of the main trace columns involved +/// in LogUp. +/// 4. A description of the each of the oracles involved in LogUp. /// -/// After the verifying the LogUp-GKR circuit, the verifier is left with unproven claims provided -/// nondeterministically by the prover about the evaluations of the MLE of the main trace columns at -/// the Lagrange kernel random elements. Those claims are (linearly) combined into one using the -/// openings combining randomness. +/// After verifying the LogUp-GKR circuit, the verifier is left with unproven claims provided +/// by the prover about the evaluations of the MLEs of the main trace columns at the evaluation +/// point defining the Lagrange kernel. Those claims are (linearly) batched into one using the +/// openings combining randomness and checked against the batched oracles using univariate IOP +/// for multi-linear evaluations of [1]. +/// +/// [1]: https://eprint.iacr.org/2023/1284 #[derive(Clone, Debug)] -pub struct GkrRandElements { - lagrange: LagrangeKernelRandElements, - openings_combining_randomness: Vec, +pub struct GkrData { + pub lagrange_kernel_eval_point: LagrangeKernelRandElements, + pub openings_combining_randomness: Vec, + pub openings: Vec, + pub oracles: Vec>, } -impl GkrRandElements { - /// Constructs a new [`GkrRandElements`] from [`LagrangeKernelRandElements`], and the openings - /// combining randomness. +impl GkrData { + /// Constructs a new [`GkrData`] from [`LagrangeKernelRandElements`], the openings combining + /// randomness and the LogUp-GKR oracles. /// - /// See [`GkrRandElements`] for a more detailed description. + /// See [`GkrData`] for a more detailed description. pub fn new( - lagrange: LagrangeKernelRandElements, + lagrange_kernel_eval_point: LagrangeKernelRandElements, openings_combining_randomness: Vec, + openings: Vec, + oracles: Vec>, ) -> Self { - Self { lagrange, openings_combining_randomness } + Self { + lagrange_kernel_eval_point, + openings_combining_randomness, + openings, + oracles, + } } /// Returns the random elements needed to build the Lagrange kernel column. pub fn lagrange_kernel_rand_elements(&self) -> &LagrangeKernelRandElements { - &self.lagrange + &self.lagrange_kernel_eval_point } /// Returns the random values used to linearly combine the openings returned from the GKR proof. pub fn openings_combining_randomness(&self) -> &[E] { &self.openings_combining_randomness } -} -/// A trait for verifying a GKR proof. -/// -/// Specifically, the use case in mind is proving the constraints of a LogUp bus using GKR, as -/// described in [Improving logarithmic derivative lookups using -/// GKR](https://eprint.iacr.org/2023/1284.pdf). -pub trait GkrVerifier { - /// The GKR proof. - type GkrProof: Deserializable; - /// The error that can occur during GKR proof verification. - type Error: ToString; - - /// Verifies the GKR proof, and returns the random elements that were used in building - /// the Lagrange kernel auxiliary column. - fn verify( - &self, - gkr_proof: Self::GkrProof, - public_coin: &mut impl RandomCoin, - ) -> Result, Self::Error> - where - E: FieldElement, - Hasher: ElementHasher; -} + pub fn openings(&self) -> &[E] { + &self.openings + } + + pub fn oracles(&self) -> &[LogUpGkrOracle] { + &self.oracles + } + + pub fn compute_batched_claim(&self) -> E { + self.openings[0] + + self + .openings + .iter() + .skip(1) + .zip(self.openings_combining_randomness.iter()) + .fold(E::ZERO, |acc, (a, b)| acc + *a * *b) + } -impl GkrVerifier for () { - type GkrProof = (); - type Error = RandomCoinError; - - fn verify( - &self, - _gkr_proof: Self::GkrProof, - _public_coin: &mut impl RandomCoin, - ) -> Result, Self::Error> - where - E: FieldElement, - Hasher: ElementHasher, - { - Ok(GkrRandElements::new(LagrangeKernelRandElements::default(), Vec::new())) + pub fn compute_batched_query(&self, query: &[E::BaseField]) -> E { + E::from(query[0]) + + query + .iter() + .skip(1) + .zip(self.openings_combining_randomness.iter()) + .fold(E::ZERO, |acc, (a, b)| acc + b.mul_base(*a)) } } diff --git a/air/src/air/boundary/mod.rs b/air/src/air/boundary/mod.rs index 7f92c80ab..2c15ac5a3 100644 --- a/air/src/air/boundary/mod.rs +++ b/air/src/air/boundary/mod.rs @@ -58,8 +58,8 @@ impl BoundaryConstraints { /// coefficients. /// * The specified assertions are not valid in the context of the computation (e.g., assertion /// column index is out of bounds). - pub fn new( - context: &AirContext, + pub fn new

( + context: &AirContext, main_assertions: Vec>, aux_assertions: Vec>, composition_coefficients: &[E], @@ -88,7 +88,7 @@ impl BoundaryConstraints { ); let trace_length = context.trace_info.length(); - let main_trace_width = context.trace_info.main_trace_width(); + let main_trace_width = context.trace_info.main_segment_width(); let aux_trace_width = context.trace_info.aux_segment_width(); // make sure the assertions are valid in the context of their respective trace segments; @@ -151,9 +151,9 @@ impl BoundaryConstraints { /// Translates the provided assertions into boundary constraints, groups the constraints by their /// divisor, and sorts the resulting groups by the degree adjustment factor. -fn group_constraints( +fn group_constraints( assertions: Vec>, - context: &AirContext, + context: &AirContext, composition_coefficients: &[E], inv_g: F::BaseField, twiddle_map: &mut BTreeMap>, diff --git a/air/src/air/coefficients.rs b/air/src/air/coefficients.rs index b82b2ac6b..ed6c3fa99 100644 --- a/air/src/air/coefficients.rs +++ b/air/src/air/coefficients.rs @@ -27,11 +27,19 @@ use math::FieldElement; /// /// The coefficients are separated into two lists: one for transition constraints and another one /// for boundary constraints. This separation is done for convenience only. +/// +/// In addition to the above, and when LogUp-GKR is enabled, there are two extra sets of +/// constraint composition coefficients that are used, namely for: +/// +/// 1. Lagrange kernel constraints, which include both transition and boundary constraints. +/// 2. S-column constraint, which is used in implementing the cohomological sum-check argument +/// of https://eprint.iacr.org/2021/930 #[derive(Debug, Clone)] pub struct ConstraintCompositionCoefficients { pub transition: Vec, pub boundary: Vec, pub lagrange: Option>, + pub s_col: Option, } /// Stores the constraint composition coefficients for the Lagrange kernel transition and boundary @@ -83,8 +91,9 @@ pub struct LagrangeConstraintsCompositionCoefficients { /// negligible increase in soundness error. The formula for the updated error can be found in /// Theorem 8 of https://eprint.iacr.org/2022/1216. /// -/// In the case when the trace polynomials contain a trace polynomial corresponding to a Lagrange -/// kernel column, the above expression of $Y(x)$ includes the additional term given by +/// In the case when LogUp-GKR is enabled, the trace polynomials contain an additional trace +/// polynomial corresponding to a Lagrange kernel column and the above expression of $Y(x)$ +/// includes the additional term given by /// /// $$ /// \gamma \cdot \frac{T_l(x) - p_S(x)}{Z_S(x)} @@ -99,8 +108,13 @@ pub struct LagrangeConstraintsCompositionCoefficients { /// 4. $p_S(X)$ is the polynomial of minimal degree interpolating the set ${(a, T_l(a)): a \in S}$. /// 5. $Z_S(X)$ is the polynomial of minimal degree vanishing over the set $S$. /// -/// Note that, if a Lagrange kernel trace polynomial is present, then $\rho^{+}$ from above should -/// be updated to be $\rho^{+} := \frac{\kappa + log_2(\nu) + 1}{\nu}$. +/// Note that when LogUp-GKR is enabled, we also have to take into account an additional column, +/// called s-column throughout, used in implementing the univariate IOP for multi-linear evaluation. +/// This means that we need and additional random value, in addition to $\gamma$ above, when +/// LogUp-GKR is enabled. +/// +/// Note that, when LogUp-GKR is enabled, $\rho^{+}$ from above should be updated to be +/// $\rho^{+} := \frac{\kappa + log_2(\nu) + 1}{\nu}$. #[derive(Debug, Clone)] pub struct DeepCompositionCoefficients { /// Trace polynomial composition coefficients $\alpha_i$. @@ -109,4 +123,6 @@ pub struct DeepCompositionCoefficients { pub constraints: Vec, /// Lagrange kernel trace polynomial composition coefficient $\gamma$. pub lagrange: Option, + /// S-column trace polynomial composition coefficient. + pub s_col: Option, } diff --git a/air/src/air/context.rs b/air/src/air/context.rs index 183f575fc..c36173ca3 100644 --- a/air/src/air/context.rs +++ b/air/src/air/context.rs @@ -14,21 +14,22 @@ use crate::{air::TransitionConstraintDegree, ProofOptions, TraceInfo}; // ================================================================================================ /// STARK parameters and trace properties for a specific execution of a computation. #[derive(Clone, PartialEq, Eq)] -pub struct AirContext { +pub struct AirContext { pub(super) options: ProofOptions, pub(super) trace_info: TraceInfo, + pub(super) pub_inputs: P, pub(super) main_transition_constraint_degrees: Vec, pub(super) aux_transition_constraint_degrees: Vec, pub(super) num_main_assertions: usize, pub(super) num_aux_assertions: usize, - pub(super) lagrange_kernel_aux_column_idx: Option, pub(super) ce_blowup_factor: usize, pub(super) trace_domain_generator: B, pub(super) lde_domain_generator: B, pub(super) num_transition_exemptions: usize, + pub(super) logup_gkr: bool, } -impl AirContext { +impl AirContext { // CONSTRUCTORS // -------------------------------------------------------------------------------------------- /// Returns a new instance of [AirContext] instantiated for computations which require a single @@ -48,6 +49,7 @@ impl AirContext { /// * `trace_info` describes a multi-segment execution trace. pub fn new( trace_info: TraceInfo, + pub_inputs: P, transition_constraint_degrees: Vec, num_assertions: usize, options: ProofOptions, @@ -58,11 +60,11 @@ impl AirContext { ); Self::new_multi_segment( trace_info, + pub_inputs, transition_constraint_degrees, Vec::new(), num_assertions, 0, - None, options, ) } @@ -91,11 +93,11 @@ impl AirContext { /// of the specified transition constraints. pub fn new_multi_segment( trace_info: TraceInfo, + pub_inputs: P, main_transition_constraint_degrees: Vec, aux_transition_constraint_degrees: Vec, num_main_assertions: usize, num_aux_assertions: usize, - lagrange_kernel_aux_column_idx: Option, options: ProofOptions, ) -> Self { assert!( @@ -104,11 +106,11 @@ impl AirContext { ); assert!(num_main_assertions > 0, "at least one assertion must be specified"); - if trace_info.is_multi_segment() { + if trace_info.is_multi_segment() && !trace_info.logup_gkr_enabled() { assert!( - !aux_transition_constraint_degrees.is_empty(), - "at least one transition constraint degree must be specified for the auxiliary trace segment" - ); + !aux_transition_constraint_degrees.is_empty(), + "at least one transition constraint degree must be specified for the auxiliary trace segment" + ); assert!( num_aux_assertions > 0, "at least one assertion must be specified against the auxiliary trace segment" @@ -124,15 +126,6 @@ impl AirContext { ); } - // validate Lagrange kernel aux column, if any - if let Some(lagrange_kernel_aux_column_idx) = lagrange_kernel_aux_column_idx { - assert!( - lagrange_kernel_aux_column_idx == trace_info.get_aux_segment_width() - 1, - "Lagrange kernel column should be the last column of the auxiliary trace: index={}, but aux trace width is {}", - lagrange_kernel_aux_column_idx, trace_info.get_aux_segment_width() - ); - } - // determine minimum blowup factor needed to evaluate transition constraints by taking // the blowup factor of the highest degree constraint let mut ce_blowup_factor = 0; @@ -161,18 +154,41 @@ impl AirContext { AirContext { options, trace_info, + pub_inputs, main_transition_constraint_degrees, aux_transition_constraint_degrees, num_main_assertions, num_aux_assertions, - lagrange_kernel_aux_column_idx, ce_blowup_factor, trace_domain_generator: B::get_root_of_unity(trace_length.ilog2()), lde_domain_generator: B::get_root_of_unity(lde_domain_size.ilog2()), num_transition_exemptions: 1, + logup_gkr: false, } } + pub fn with_logup_gkr( + trace_info: TraceInfo, + pub_inputs: P, + main_transition_constraint_degrees: Vec, + aux_transition_constraint_degrees: Vec, + num_main_assertions: usize, + num_aux_assertions: usize, + options: ProofOptions, + ) -> Self { + let mut air_context = Self::new_multi_segment( + trace_info, + pub_inputs, + main_transition_constraint_degrees, + aux_transition_constraint_degrees, + num_main_assertions, + num_aux_assertions, + options, + ); + air_context.logup_gkr = true; + air_context + } + // PUBLIC ACCESSORS // -------------------------------------------------------------------------------------------- @@ -209,6 +225,10 @@ impl AirContext { self.trace_info.length() * self.options.blowup_factor() } + pub fn public_inputs(&self) -> &P { + &self.pub_inputs + } + /// Returns the number of transition constraints for a computation, excluding the Lagrange /// kernel transition constraints, which are managed separately. /// @@ -232,12 +252,16 @@ impl AirContext { /// Returns the index of the auxiliary column which implements the Lagrange kernel, if any pub fn lagrange_kernel_aux_column_idx(&self) -> Option { - self.lagrange_kernel_aux_column_idx + if self.logup_gkr_enabled() { + Some(self.trace_info().aux_segment_width() - 1) + } else { + None + } } - /// Returns true if the auxiliary trace segment contains a Lagrange kernel column - pub fn has_lagrange_kernel_aux_column(&self) -> bool { - self.lagrange_kernel_aux_column_idx().is_some() + /// Returns true if LogUp-GKR is enabled. + pub fn logup_gkr_enabled(&self) -> bool { + self.logup_gkr } /// Returns the total number of assertions defined for a computation, excluding the Lagrange diff --git a/air/src/air/logup_gkr.rs b/air/src/air/logup_gkr.rs index 98054c938..0438064d9 100644 --- a/air/src/air/logup_gkr.rs +++ b/air/src/air/logup_gkr.rs @@ -4,10 +4,12 @@ // LICENSE file in the root directory of this source tree. use alloc::vec::Vec; +use core::marker::PhantomData; +use crypto::{ElementHasher, RandomCoin}; use math::{ExtensionOf, FieldElement, StarkField, ToElements}; -use super::EvaluationFrame; +use super::{EvaluationFrame, GkrData, LagrangeKernelRandElements}; /// A trait containing the necessary information in order to run the LogUp-GKR protocol of [1]. /// @@ -25,7 +27,7 @@ pub trait LogUpGkrEvaluator: Clone + Sync { /// Gets a list of all oracles involved in LogUp-GKR; this is intended to be used in construction of /// MLEs. - fn get_oracles(&self) -> Vec>; + fn get_oracles(&self) -> &[LogUpGkrOracle]; /// Returns the number of random values needed to evaluate a query. fn get_num_rand_values(&self) -> usize; @@ -79,11 +81,122 @@ pub trait LogUpGkrEvaluator: Clone + Sync { { E::ZERO } + + /// Generates the data needed for running the univariate IOP for multi-linear evaluation of [1]. + /// + /// This mainly generates the batching randomness used to batch a number of multi-linear + /// evaluation claims and includes some additional data that is needed for building/verifying + /// the univariate IOP for multi-linear evaluation of [1]. + /// + /// This is the $\lambda$ randomness in section 5.2 in [1] but using different random values for + /// each term instead of powers of a single random element. + /// + /// [1]: https://eprint.iacr.org/2023/1284 + fn generate_univariate_iop_for_multi_linear_opening_data( + &self, + openings: Vec, + eval_point: Vec, + public_coin: &mut impl RandomCoin, + ) -> GkrData + where + E: FieldElement, + H: ElementHasher, + { + public_coin.reseed(H::hash_elements(&openings)); + + let mut batching_randomness = Vec::with_capacity(openings.len() - 1); + for _ in 0..openings.len() - 1 { + batching_randomness.push(public_coin.draw().expect("failed to generate randomness")) + } + + GkrData::new( + LagrangeKernelRandElements::new(eval_point), + batching_randomness, + openings, + self.get_oracles().to_vec(), + ) + } +} + +#[derive(Clone, Default)] +pub(crate) struct PhantomLogUpGkrEval> { + _field: PhantomData, + _public_inputs: PhantomData

, +} + +impl PhantomLogUpGkrEval +where + B: StarkField, + P: Clone + Send + Sync + ToElements, +{ + pub fn new() -> Self { + Self { + _field: PhantomData, + _public_inputs: PhantomData, + } + } +} + +impl LogUpGkrEvaluator for PhantomLogUpGkrEval +where + B: StarkField, + P: Clone + Send + Sync + ToElements, +{ + type BaseField = B; + + type PublicInputs = P; + + fn get_oracles(&self) -> &[LogUpGkrOracle] { + panic!("LogUpGkrEvaluator method called but LogUp-GKR is not implemented") + } + + fn get_num_rand_values(&self) -> usize { + panic!("LogUpGkrEvaluator method called but LogUp-GKR is not implemented") + } + + fn get_num_fractions(&self) -> usize { + panic!("LogUpGkrEvaluator method called but LogUp-GKR is not implemented") + } + + fn max_degree(&self) -> usize { + panic!("LogUpGkrEvaluator method called but LogUp-GKR is not implemented") + } + + fn build_query(&self, _frame: &EvaluationFrame, _periodic_values: &[E], _query: &mut [E]) + where + E: FieldElement, + { + panic!("LogUpGkrEvaluator method called but LogUp-GKR is not implemented") + } + + fn evaluate_query( + &self, + _query: &[F], + _rand_values: &[E], + _numerator: &mut [E], + _denominator: &mut [E], + ) where + F: FieldElement, + E: FieldElement + ExtensionOf, + { + panic!("LogUpGkrEvaluator method called but LogUp-GKR is not implemented") + } + + fn compute_claim(&self, _inputs: &Self::PublicInputs, _rand_values: &[E]) -> E + where + E: FieldElement, + { + panic!("LogUpGkrEvaluator method called but LogUp-GKR is not implemented") + } } #[derive(Clone, Debug, PartialEq, PartialOrd, Eq, Ord)] pub enum LogUpGkrOracle { + /// A column with a given index in the main trace segment. CurrentRow(usize), + /// A column with a given index in the main trace segment but shifted upwards. NextRow(usize), + /// A virtual periodic column defined by its values in a given cycle. Note that the cycle length + /// must be a power of 2. PeriodicValue(Vec), } diff --git a/air/src/air/mod.rs b/air/src/air/mod.rs index 07f38cce1..5dcee0717 100644 --- a/air/src/air/mod.rs +++ b/air/src/air/mod.rs @@ -6,12 +6,13 @@ use alloc::{collections::BTreeMap, vec::Vec}; use crypto::{RandomCoin, RandomCoinError}; +use logup_gkr::PhantomLogUpGkrEval; use math::{fft, ExtensibleField, ExtensionOf, FieldElement, StarkField, ToElements}; use crate::ProofOptions; mod aux; -pub use aux::{AuxRandElements, GkrRandElements, GkrVerifier}; +pub use aux::{AuxRandElements, GkrData}; mod trace_info; pub use trace_info::TraceInfo; @@ -45,7 +46,6 @@ pub use coefficients::{ mod divisor; pub use divisor::ConstraintDivisor; -use utils::{Deserializable, Serializable}; #[cfg(test)] mod tests; @@ -195,13 +195,7 @@ pub trait Air: Send + Sync { /// A type defining shape of public inputs for the computation described by this protocol. /// This could be any type as long as it can be serialized into a sequence of field elements. - type PublicInputs: ToElements + Send; - - /// An GKR proof object. If not needed, set to `()`. - type GkrProof: Serializable + Deserializable + Send; - - /// A verifier for verifying GKR proofs. If not needed, set to `()`. - type GkrVerifier: GkrVerifier; + type PublicInputs: ToElements + Clone + Send + Sync; // REQUIRED METHODS // -------------------------------------------------------------------------------------------- @@ -217,7 +211,7 @@ pub trait Air: Send + Sync { fn new(trace_info: TraceInfo, pub_inputs: Self::PublicInputs, options: ProofOptions) -> Self; /// Returns context for this instance of the computation. - fn context(&self) -> &AirContext; + fn context(&self) -> &AirContext; /// Evaluates transition constraints over the specified evaluation frame. /// @@ -306,16 +300,15 @@ pub trait Air: Send + Sync { Vec::new() } - // AUXILIARY PROOF VERIFIER + // LOGUP-GKR EVALUATOR // -------------------------------------------------------------------------------------------- - /// Returns the [`GkrVerifier`] to be used to verify the GKR proof. - /// - /// Leave unimplemented if the `Air` doesn't use a GKR proof. - fn get_gkr_proof_verifier>( + /// Returns the object needed for the LogUp-GKR argument. + fn get_logup_gkr_evaluator( &self, - ) -> Self::GkrVerifier { - unimplemented!("`get_auxiliary_proof_verifier()` must be implemented when the proof contains a GKR proof"); + ) -> impl LogUpGkrEvaluator + { + PhantomLogUpGkrEval::new() } // PROVIDED METHODS @@ -345,13 +338,16 @@ pub trait Air: Send + Sync { lagrange_composition_coefficients: LagrangeConstraintsCompositionCoefficients, lagrange_kernel_rand_elements: &LagrangeKernelRandElements, ) -> Option> { - self.context().lagrange_kernel_aux_column_idx().map(|col_idx| { - LagrangeKernelConstraints::new( + if self.context().logup_gkr_enabled() { + let col_idx = self.context().trace_info().aux_segment_width() - 1; + Some(LagrangeKernelConstraints::new( lagrange_composition_coefficients, lagrange_kernel_rand_elements, col_idx, - ) - }) + )) + } else { + None + } } /// Returns values for all periodic columns used in the computation. @@ -548,7 +544,7 @@ pub trait Air: Send + Sync { b_coefficients.push(public_coin.draw()?); } - let lagrange = if self.context().has_lagrange_kernel_aux_column() { + let lagrange = if self.context().logup_gkr_enabled() { let mut lagrange_kernel_t_coefficients = Vec::new(); for _ in 0..self.context().trace_len().ilog2() { lagrange_kernel_t_coefficients.push(public_coin.draw()?); @@ -564,10 +560,17 @@ pub trait Air: Send + Sync { None }; + let s_col = if self.context().logup_gkr_enabled() { + Some(public_coin.draw()?) + } else { + None + }; + Ok(ConstraintCompositionCoefficients { transition: t_coefficients, boundary: b_coefficients, lagrange, + s_col, }) } @@ -591,7 +594,13 @@ pub trait Air: Send + Sync { c_coefficients.push(public_coin.draw()?); } - let lagrange_cc = if self.context().has_lagrange_kernel_aux_column() { + let lagrange_cc = if self.context().logup_gkr_enabled() { + Some(public_coin.draw()?) + } else { + None + }; + + let s_col = if self.context().logup_gkr_enabled() { Some(public_coin.draw()?) } else { None @@ -601,6 +610,7 @@ pub trait Air: Send + Sync { trace: t_coefficients, constraints: c_coefficients, lagrange: lagrange_cc, + s_col, }) } } diff --git a/air/src/air/tests.rs b/air/src/air/tests.rs index e0063ed3b..5e9871ca5 100644 --- a/air/src/air/tests.rs +++ b/air/src/air/tests.rs @@ -9,8 +9,8 @@ use crypto::{hashers::Blake3_256, DefaultRandomCoin, RandomCoin}; use math::{fields::f64::BaseElement, get_power_series, polynom, FieldElement, StarkField}; use super::{ - Air, AirContext, Assertion, EvaluationFrame, ProofOptions, TraceInfo, - TransitionConstraintDegree, + logup_gkr::PhantomLogUpGkrEval, Air, AirContext, Assertion, EvaluationFrame, ProofOptions, + TraceInfo, TransitionConstraintDegree, }; use crate::FieldExtension; @@ -192,7 +192,7 @@ fn get_boundary_constraints() { // ================================================================================================ struct MockAir { - context: AirContext, + context: AirContext, assertions: Vec>, periodic_columns: Vec>, } @@ -225,8 +225,7 @@ impl MockAir { impl Air for MockAir { type BaseField = BaseElement; type PublicInputs = (); - type GkrProof = (); - type GkrVerifier = (); + //type LogUpGkrEvaluator = DummyLogUpGkrEval; fn new(trace_info: TraceInfo, _pub_inputs: (), _options: ProofOptions) -> Self { let num_assertions = trace_info.meta()[0] as usize; @@ -238,7 +237,7 @@ impl Air for MockAir { } } - fn context(&self) -> &AirContext { + fn context(&self) -> &AirContext { &self.context } @@ -257,6 +256,13 @@ impl Air for MockAir { _result: &mut [E], ) { } + + fn get_logup_gkr_evaluator( + &self, + ) -> impl super::LogUpGkrEvaluator + { + PhantomLogUpGkrEval::default() + } } // UTILITY FUNCTIONS @@ -266,11 +272,11 @@ pub fn build_context( trace_length: usize, trace_width: usize, num_assertions: usize, -) -> AirContext { +) -> AirContext { let options = ProofOptions::new(32, 8, 0, FieldExtension::None, 4, 31); let t_degrees = vec![TransitionConstraintDegree::new(2)]; let trace_info = TraceInfo::new(trace_width, trace_length); - AirContext::new(trace_info, t_degrees, num_assertions, options) + AirContext::new(trace_info, (), t_degrees, num_assertions, options) } pub fn build_prng() -> DefaultRandomCoin> { diff --git a/air/src/air/trace_info.rs b/air/src/air/trace_info.rs index 99ff4aa6d..44aa0a7ea 100644 --- a/air/src/air/trace_info.rs +++ b/air/src/air/trace_info.rs @@ -27,6 +27,7 @@ pub struct TraceInfo { num_aux_segment_rands: usize, trace_length: usize, trace_meta: Vec, + logup_gkr: bool, } impl TraceInfo { @@ -65,7 +66,7 @@ impl TraceInfo { /// * Length of `meta` is greater than 65535; pub fn with_meta(width: usize, length: usize, meta: Vec) -> Self { assert!(width > 0, "trace width must be greater than 0"); - Self::new_multi_segment(width, 0, 0, length, meta) + Self::new_multi_segment(width, 0, 0, length, meta, false) } /// Creates a new [TraceInfo] with main and auxiliary segments. @@ -90,6 +91,7 @@ impl TraceInfo { num_aux_segment_rands: usize, trace_length: usize, trace_meta: Vec, + logup_gkr: bool, ) -> Self { assert!( trace_length >= Self::MIN_TRACE_LENGTH, @@ -138,6 +140,7 @@ impl TraceInfo { num_aux_segment_rands, trace_length, trace_meta, + logup_gkr, } } @@ -146,9 +149,13 @@ impl TraceInfo { /// Returns the total number of columns in an execution trace. /// + /// When LogUp-GKR is enabled, we also account for two extra columns, in the auxiliary segment, + /// which are needed for implementing the univariate IOP for multi-linear evaluation in + /// https://eprint.iacr.org/2023/1284. + /// /// This is guaranteed to be between 1 and 255. pub fn width(&self) -> usize { - self.main_segment_width + self.aux_segment_width + self.main_segment_width + self.aux_segment_width + 2 * self.logup_gkr as usize } /// Returns execution trace length. @@ -171,13 +178,13 @@ impl TraceInfo { /// Returns the number of columns in the main segment of an execution trace. /// /// This is guaranteed to be between 1 and 255. - pub fn main_trace_width(&self) -> usize { + pub fn main_segment_width(&self) -> usize { self.main_segment_width } /// Returns the number of columns in the auxiliary segment of an execution trace. pub fn aux_segment_width(&self) -> usize { - self.aux_segment_width + self.aux_segment_width + 2 * self.logup_gkr as usize } /// Returns the total number of segments in an execution trace. @@ -198,9 +205,9 @@ impl TraceInfo { } } - /// Returns the number of columns in the auxiliary trace segment. - pub fn get_aux_segment_width(&self) -> usize { - self.aux_segment_width + /// Returns a boolean indicating whether LogUp-GKR is enabled. + pub fn logup_gkr_enabled(&self) -> bool { + self.logup_gkr } /// Returns the number of random elements needed to build all auxiliary columns, except for the @@ -264,6 +271,9 @@ impl Serializable for TraceInfo { // store trace meta target.write_u16(self.trace_meta.len() as u16); target.write_bytes(&self.trace_meta); + + // write bool indicating if LogUp-GKR is used + target.write_bool(self.logup_gkr); } } @@ -326,12 +336,16 @@ impl Deserializable for TraceInfo { vec![] }; + // read `logup_gkr` + let logup_gkr = source.read_bool()?; + Ok(Self::new_multi_segment( main_segment_width, aux_segment_width, num_aux_segment_rands, trace_length, trace_meta, + logup_gkr, )) } } @@ -387,6 +401,7 @@ mod tests { aux_rands, trace_length as usize, trace_meta, + false, ); assert_eq!(expected, info.to_elements()); diff --git a/air/src/air/transition/mod.rs b/air/src/air/transition/mod.rs index 60e641817..d29cbbb8b 100644 --- a/air/src/air/transition/mod.rs +++ b/air/src/air/transition/mod.rs @@ -46,7 +46,7 @@ impl TransitionConstraints { /// # Panics /// Panics if the number of transition constraints in the context does not match the number of /// provided composition coefficients. - pub fn new(context: &AirContext, composition_coefficients: &[E]) -> Self { + pub fn new

(context: &AirContext, composition_coefficients: &[E]) -> Self { assert_eq!( context.num_transition_constraints(), composition_coefficients.len(), diff --git a/air/src/lib.rs b/air/src/lib.rs index aaede0bda..2993306b9 100644 --- a/air/src/lib.rs +++ b/air/src/lib.rs @@ -44,7 +44,7 @@ mod air; pub use air::{ Air, AirContext, Assertion, AuxRandElements, BoundaryConstraint, BoundaryConstraintGroup, BoundaryConstraints, ConstraintCompositionCoefficients, ConstraintDivisor, - DeepCompositionCoefficients, EvaluationFrame, GkrRandElements, GkrVerifier, + DeepCompositionCoefficients, EvaluationFrame, GkrData, LagrangeConstraintsCompositionCoefficients, LagrangeKernelBoundaryConstraint, LagrangeKernelConstraints, LagrangeKernelEvaluationFrame, LagrangeKernelRandElements, LagrangeKernelTransitionConstraints, LogUpGkrEvaluator, LogUpGkrOracle, TraceInfo, diff --git a/air/src/proof/context.rs b/air/src/proof/context.rs index 83c2beece..73152709a 100644 --- a/air/src/proof/context.rs +++ b/air/src/proof/context.rs @@ -190,6 +190,7 @@ mod tests { aux_rands, trace_length, vec![], + false, ); let mut expected = trace_info.to_elements(); @@ -213,8 +214,14 @@ mod tests { fri_folding_factor as usize, fri_remainder_max_degree as usize, ); - let trace_info = - TraceInfo::new_multi_segment(main_width, aux_width, aux_rands, trace_length, vec![]); + let trace_info = TraceInfo::new_multi_segment( + main_width, + aux_width, + aux_rands, + trace_length, + vec![], + false, + ); let context = Context::new::(trace_info, options); assert_eq!(expected, context.to_elements()); } diff --git a/air/src/proof/ood_frame.rs b/air/src/proof/ood_frame.rs index d4b3f14ec..feab1b260 100644 --- a/air/src/proof/ood_frame.rs +++ b/air/src/proof/ood_frame.rs @@ -229,6 +229,8 @@ impl Deserializable for OodFrame { // OOD FRAME TRACE STATES // ================================================================================================ +/// Stores trace evaluations at an OOD point. +/// /// Stores the trace evaluations at `z` and `gz`, where `z` is a random Field element in /// `current_row` and `next_row`, respectively. If the Air contains a Lagrange kernel auxiliary /// column, then that column interpolated polynomial will be evaluated at `z`, `gz`, `g^2 z`, ... diff --git a/crypto/src/merkle/concurrent.rs b/crypto/src/merkle/concurrent.rs index 637bd51b5..7a3ba077f 100644 --- a/crypto/src/merkle/concurrent.rs +++ b/crypto/src/merkle/concurrent.rs @@ -18,9 +18,10 @@ pub const MIN_CONCURRENT_LEAVES: usize = 1024; // PUBLIC FUNCTIONS // ================================================================================================ -/// Builds all internal nodes of the Merkle using all available threads and stores the -/// results in a single vector such that root of the tree is at position 1, nodes immediately -/// under the root is at positions 2 and 3 etc. +/// Builds all internal nodes of the Merkle tree. +/// +/// This uses all available threads and stores the results in a single vector such that root of +/// the tree is at position 1, nodes immediately under the root is at positions 2 and 3 etc. pub fn build_merkle_nodes(leaves: &[H::Digest]) -> Vec { let n = leaves.len() / 2; diff --git a/examples/src/fibonacci/fib2/air.rs b/examples/src/fibonacci/fib2/air.rs index 9e5d75a48..4019ddcae 100644 --- a/examples/src/fibonacci/fib2/air.rs +++ b/examples/src/fibonacci/fib2/air.rs @@ -14,15 +14,13 @@ use crate::utils::are_equal; // ================================================================================================ pub struct FibAir { - context: AirContext, + context: AirContext, result: BaseElement, } impl Air for FibAir { type BaseField = BaseElement; type PublicInputs = BaseElement; - type GkrProof = (); - type GkrVerifier = (); // CONSTRUCTOR // -------------------------------------------------------------------------------------------- @@ -30,12 +28,12 @@ impl Air for FibAir { let degrees = vec![TransitionConstraintDegree::new(1), TransitionConstraintDegree::new(1)]; assert_eq!(TRACE_WIDTH, trace_info.width()); FibAir { - context: AirContext::new(trace_info, degrees, 3, options), + context: AirContext::new(trace_info, pub_inputs, degrees, 3, options), result: pub_inputs, } } - fn context(&self) -> &AirContext { + fn context(&self) -> &AirContext { &self.context } diff --git a/examples/src/fibonacci/fib8/air.rs b/examples/src/fibonacci/fib8/air.rs index 4d7aef9ba..17edc7970 100644 --- a/examples/src/fibonacci/fib8/air.rs +++ b/examples/src/fibonacci/fib8/air.rs @@ -15,15 +15,13 @@ use crate::utils::are_equal; // ================================================================================================ pub struct Fib8Air { - context: AirContext, + context: AirContext, result: BaseElement, } impl Air for Fib8Air { type BaseField = BaseElement; type PublicInputs = BaseElement; - type GkrProof = (); - type GkrVerifier = (); // CONSTRUCTOR // -------------------------------------------------------------------------------------------- @@ -31,12 +29,12 @@ impl Air for Fib8Air { let degrees = vec![TransitionConstraintDegree::new(1), TransitionConstraintDegree::new(1)]; assert_eq!(TRACE_WIDTH, trace_info.width()); Fib8Air { - context: AirContext::new(trace_info, degrees, 3, options), + context: AirContext::new(trace_info, pub_inputs, degrees, 3, options), result: pub_inputs, } } - fn context(&self) -> &AirContext { + fn context(&self) -> &AirContext { &self.context } diff --git a/examples/src/fibonacci/fib_small/air.rs b/examples/src/fibonacci/fib_small/air.rs index 66580c872..b48eb734b 100644 --- a/examples/src/fibonacci/fib_small/air.rs +++ b/examples/src/fibonacci/fib_small/air.rs @@ -14,15 +14,13 @@ use crate::utils::are_equal; // ================================================================================================ pub struct FibSmall { - context: AirContext, + context: AirContext, result: BaseElement, } impl Air for FibSmall { type BaseField = BaseElement; type PublicInputs = BaseElement; - type GkrProof = (); - type GkrVerifier = (); // CONSTRUCTOR // -------------------------------------------------------------------------------------------- @@ -30,12 +28,12 @@ impl Air for FibSmall { let degrees = vec![TransitionConstraintDegree::new(1), TransitionConstraintDegree::new(1)]; assert_eq!(TRACE_WIDTH, trace_info.width()); FibSmall { - context: AirContext::new(trace_info, degrees, 3, options), + context: AirContext::new(trace_info, pub_inputs, degrees, 3, options), result: pub_inputs, } } - fn context(&self) -> &AirContext { + fn context(&self) -> &AirContext { &self.context } diff --git a/examples/src/fibonacci/mulfib2/air.rs b/examples/src/fibonacci/mulfib2/air.rs index 3190d2e41..501adf6af 100644 --- a/examples/src/fibonacci/mulfib2/air.rs +++ b/examples/src/fibonacci/mulfib2/air.rs @@ -16,15 +16,13 @@ use crate::utils::are_equal; // ================================================================================================ pub struct MulFib2Air { - context: AirContext, + context: AirContext, result: BaseElement, } impl Air for MulFib2Air { type BaseField = BaseElement; type PublicInputs = BaseElement; - type GkrProof = (); - type GkrVerifier = (); // CONSTRUCTOR // -------------------------------------------------------------------------------------------- @@ -32,12 +30,12 @@ impl Air for MulFib2Air { let degrees = vec![TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(2)]; assert_eq!(TRACE_WIDTH, trace_info.width()); MulFib2Air { - context: AirContext::new(trace_info, degrees, 3, options), + context: AirContext::new(trace_info, pub_inputs, degrees, 3, options), result: pub_inputs, } } - fn context(&self) -> &AirContext { + fn context(&self) -> &AirContext { &self.context } diff --git a/examples/src/fibonacci/mulfib8/air.rs b/examples/src/fibonacci/mulfib8/air.rs index bbbe1dea0..c76f4f091 100644 --- a/examples/src/fibonacci/mulfib8/air.rs +++ b/examples/src/fibonacci/mulfib8/air.rs @@ -16,15 +16,13 @@ use crate::utils::are_equal; // ================================================================================================ pub struct MulFib8Air { - context: AirContext, + context: AirContext, result: BaseElement, } impl Air for MulFib8Air { type BaseField = BaseElement; type PublicInputs = BaseElement; - type GkrProof = (); - type GkrVerifier = (); // CONSTRUCTOR // -------------------------------------------------------------------------------------------- @@ -41,12 +39,12 @@ impl Air for MulFib8Air { ]; assert_eq!(TRACE_WIDTH, trace_info.width()); MulFib8Air { - context: AirContext::new(trace_info, degrees, 3, options), + context: AirContext::new(trace_info, pub_inputs, degrees, 3, options), result: pub_inputs, } } - fn context(&self) -> &AirContext { + fn context(&self) -> &AirContext { &self.context } diff --git a/examples/src/lamport/aggregate/air.rs b/examples/src/lamport/aggregate/air.rs index 29b6e2372..57708fd74 100644 --- a/examples/src/lamport/aggregate/air.rs +++ b/examples/src/lamport/aggregate/air.rs @@ -38,7 +38,7 @@ impl ToElements for PublicInputs { } pub struct LamportAggregateAir { - context: AirContext, + context: AirContext, pub_keys: Vec<[BaseElement; 2]>, messages: Vec<[BaseElement; 2]>, } @@ -46,8 +46,6 @@ pub struct LamportAggregateAir { impl Air for LamportAggregateAir { type BaseField = BaseElement; type PublicInputs = PublicInputs; - type GkrProof = (); - type GkrVerifier = (); // CONSTRUCTOR // -------------------------------------------------------------------------------------------- @@ -88,13 +86,13 @@ impl Air for LamportAggregateAir { ]; assert_eq!(TRACE_WIDTH, trace_info.width()); LamportAggregateAir { - context: AirContext::new(trace_info, degrees, 22, options), + context: AirContext::new(trace_info, pub_inputs.clone(), degrees, 22, options), pub_keys: pub_inputs.pub_keys, messages: pub_inputs.messages, } } - fn context(&self) -> &AirContext { + fn context(&self) -> &AirContext { &self.context } diff --git a/examples/src/lamport/threshold/air.rs b/examples/src/lamport/threshold/air.rs index 41983c743..b68a2a24d 100644 --- a/examples/src/lamport/threshold/air.rs +++ b/examples/src/lamport/threshold/air.rs @@ -22,7 +22,7 @@ const TWO: BaseElement = BaseElement::new(2); // THRESHOLD LAMPORT PLUS SIGNATURE AIR // ================================================================================================ -#[derive(Clone)] +#[derive(Clone, Default)] pub struct PublicInputs { pub pub_key_root: [BaseElement; 2], pub num_pub_keys: usize, @@ -41,7 +41,7 @@ impl ToElements for PublicInputs { } pub struct LamportThresholdAir { - context: AirContext, + context: AirContext, pub_key_root: [BaseElement; 2], num_pub_keys: usize, num_signatures: usize, @@ -51,8 +51,6 @@ pub struct LamportThresholdAir { impl Air for LamportThresholdAir { type BaseField = BaseElement; type PublicInputs = PublicInputs; - type GkrProof = (); - type GkrVerifier = (); // CONSTRUCTOR // -------------------------------------------------------------------------------------------- @@ -99,7 +97,7 @@ impl Air for LamportThresholdAir { ]; assert_eq!(TRACE_WIDTH, trace_info.width()); LamportThresholdAir { - context: AirContext::new(trace_info, degrees, 26, options), + context: AirContext::new(trace_info, pub_inputs.clone(), degrees, 26, options), pub_key_root: pub_inputs.pub_key_root, num_pub_keys: pub_inputs.num_pub_keys, num_signatures: pub_inputs.num_signatures, @@ -244,7 +242,7 @@ impl Air for LamportThresholdAir { result } - fn context(&self) -> &AirContext { + fn context(&self) -> &AirContext { &self.context } } diff --git a/examples/src/merkle/air.rs b/examples/src/merkle/air.rs index e0c8b177c..5d38397ff 100644 --- a/examples/src/merkle/air.rs +++ b/examples/src/merkle/air.rs @@ -14,6 +14,7 @@ use crate::utils::{are_equal, is_binary, is_zero, not, EvaluationResult}; // MERKLE PATH VERIFICATION AIR // ================================================================================================ +#[derive(Clone)] pub struct PublicInputs { pub tree_root: [BaseElement; 2], } @@ -25,15 +26,13 @@ impl ToElements for PublicInputs { } pub struct MerkleAir { - context: AirContext, + context: AirContext, tree_root: [BaseElement; 2], } impl Air for MerkleAir { type BaseField = BaseElement; type PublicInputs = PublicInputs; - type GkrProof = (); - type GkrVerifier = (); // CONSTRUCTOR // -------------------------------------------------------------------------------------------- @@ -49,12 +48,12 @@ impl Air for MerkleAir { ]; assert_eq!(TRACE_WIDTH, trace_info.width()); MerkleAir { - context: AirContext::new(trace_info, degrees, 4, options), + context: AirContext::new(trace_info, pub_inputs.clone(), degrees, 4, options), tree_root: pub_inputs.tree_root, } } - fn context(&self) -> &AirContext { + fn context(&self) -> &AirContext { &self.context } diff --git a/examples/src/rescue/air.rs b/examples/src/rescue/air.rs index a9d3d5ebb..09bf9c450 100644 --- a/examples/src/rescue/air.rs +++ b/examples/src/rescue/air.rs @@ -37,6 +37,7 @@ const CYCLE_MASK: [BaseElement; CYCLE_LENGTH] = [ // RESCUE AIR // ================================================================================================ +#[derive(Clone)] pub struct PublicInputs { pub seed: [BaseElement; 2], pub result: [BaseElement; 2], @@ -51,7 +52,7 @@ impl ToElements for PublicInputs { } pub struct RescueAir { - context: AirContext, + context: AirContext, seed: [BaseElement; 2], result: [BaseElement; 2], } @@ -59,8 +60,6 @@ pub struct RescueAir { impl Air for RescueAir { type BaseField = BaseElement; type PublicInputs = PublicInputs; - type GkrProof = (); - type GkrVerifier = (); // CONSTRUCTOR // -------------------------------------------------------------------------------------------- @@ -73,13 +72,13 @@ impl Air for RescueAir { ]; assert_eq!(TRACE_WIDTH, trace_info.width()); RescueAir { - context: AirContext::new(trace_info, degrees, 4, options), + context: AirContext::new(trace_info, pub_inputs.clone(), degrees, 4, options), seed: pub_inputs.seed, result: pub_inputs.result, } } - fn context(&self) -> &AirContext { + fn context(&self) -> &AirContext { &self.context } diff --git a/examples/src/rescue_raps/air.rs b/examples/src/rescue_raps/air.rs index 6fb5321b1..694e189bc 100644 --- a/examples/src/rescue_raps/air.rs +++ b/examples/src/rescue_raps/air.rs @@ -41,6 +41,7 @@ const CYCLE_MASK: [BaseElement; CYCLE_LENGTH] = [ // RESCUE AIR // ================================================================================================ +#[derive(Clone)] pub struct PublicInputs { pub result: [[BaseElement; 2]; 2], } @@ -52,15 +53,13 @@ impl ToElements for PublicInputs { } pub struct RescueRapsAir { - context: AirContext, + context: AirContext, result: [[BaseElement; 2]; 2], } impl Air for RescueRapsAir { type BaseField = BaseElement; type PublicInputs = PublicInputs; - type GkrProof = (); - type GkrVerifier = (); // CONSTRUCTOR // -------------------------------------------------------------------------------------------- @@ -76,18 +75,18 @@ impl Air for RescueRapsAir { RescueRapsAir { context: AirContext::new_multi_segment( trace_info, + pub_inputs.clone(), main_degrees, aux_degrees, 8, 2, - None, options, ), result: pub_inputs.result, } } - fn context(&self) -> &AirContext { + fn context(&self) -> &AirContext { &self.context } diff --git a/examples/src/rescue_raps/custom_trace_table.rs b/examples/src/rescue_raps/custom_trace_table.rs index 063d509a4..f6f9d075b 100644 --- a/examples/src/rescue_raps/custom_trace_table.rs +++ b/examples/src/rescue_raps/custom_trace_table.rs @@ -89,7 +89,7 @@ impl RapTraceTable { let columns = unsafe { (0..width).map(|_| uninit_vector(length)).collect() }; Self { - info: TraceInfo::new_multi_segment(width, 3, 3, length, meta), + info: TraceInfo::new_multi_segment(width, 3, 3, length, meta, false), trace: ColMatrix::new(columns), } } @@ -113,7 +113,7 @@ impl RapTraceTable { I: Fn(&mut [B]), U: Fn(usize, &mut [B]), { - let mut state = vec![B::ZERO; self.info.main_trace_width()]; + let mut state = vec![B::ZERO; self.info.main_segment_width()]; init(&mut state); self.update_row(0, &state); @@ -133,7 +133,7 @@ impl RapTraceTable { /// Returns the number of columns in this execution trace. pub fn width(&self) -> usize { - self.info.main_trace_width() + self.info.main_segment_width() } /// Returns value of the cell in the specified column at the specified row of this trace. diff --git a/examples/src/rescue_raps/prover.rs b/examples/src/rescue_raps/prover.rs index 7adee9bbb..6e50f1572 100644 --- a/examples/src/rescue_raps/prover.rs +++ b/examples/src/rescue_raps/prover.rs @@ -139,16 +139,11 @@ where DefaultConstraintEvaluator::new(air, aux_rand_elements, composition_coefficients) } - fn build_aux_trace( - &self, - trace: &Self::Trace, - aux_rand_elements: &AuxRandElements, - ) -> ColMatrix + fn build_aux_trace(&self, trace: &Self::Trace, aux_rand_elements: &[E]) -> ColMatrix where E: FieldElement, { let main_trace = trace.main_segment(); - let rand_elements = aux_rand_elements.rand_elements(); let mut current_row = unsafe { uninit_vector(main_trace.num_cols()) }; let mut next_row = unsafe { uninit_vector(main_trace.num_cols()) }; @@ -157,10 +152,10 @@ where // Columns storing the copied values for the permutation argument are not necessary, but // help understanding the construction of RAPs and are kept for illustrative purposes. - aux_columns[0][0] = - rand_elements[0] * current_row[0].into() + rand_elements[1] * current_row[1].into(); - aux_columns[1][0] = - rand_elements[0] * current_row[4].into() + rand_elements[1] * current_row[5].into(); + aux_columns[0][0] = aux_rand_elements[0] * current_row[0].into() + + aux_rand_elements[1] * current_row[1].into(); + aux_columns[1][0] = aux_rand_elements[0] * current_row[4].into() + + aux_rand_elements[1] * current_row[5].into(); // Permutation argument column aux_columns[2][0] = E::ONE; @@ -172,14 +167,16 @@ where main_trace.read_row_into(index, &mut current_row); main_trace.read_row_into(index + 1, &mut next_row); - aux_columns[0][index] = rand_elements[0] * (next_row[0] - current_row[0]).into() - + rand_elements[1] * (next_row[1] - current_row[1]).into(); - aux_columns[1][index] = rand_elements[0] * (next_row[4] - current_row[4]).into() - + rand_elements[1] * (next_row[5] - current_row[5]).into(); + aux_columns[0][index] = aux_rand_elements[0] + * (next_row[0] - current_row[0]).into() + + aux_rand_elements[1] * (next_row[1] - current_row[1]).into(); + aux_columns[1][index] = aux_rand_elements[0] + * (next_row[4] - current_row[4]).into() + + aux_rand_elements[1] * (next_row[5] - current_row[5]).into(); } - let num = aux_columns[0][index - 1] + rand_elements[2]; - let denom = aux_columns[1][index - 1] + rand_elements[2]; + let num = aux_columns[0][index - 1] + aux_rand_elements[2]; + let denom = aux_columns[1][index - 1] + aux_rand_elements[2]; aux_columns[2][index] = aux_columns[2][index - 1] * num * denom.inv(); } diff --git a/examples/src/utils/rescue.rs b/examples/src/utils/rescue.rs index e09cb094e..be297fcf3 100644 --- a/examples/src/utils/rescue.rs +++ b/examples/src/utils/rescue.rs @@ -21,6 +21,8 @@ pub const RATE_WIDTH: usize = 4; /// Two elements (32-bytes) are returned as digest. const DIGEST_SIZE: usize = 2; +/// Number of rounds used in Rescue. +/// /// The number of rounds is set to 7 to provide 128-bit security level with 40% security margin; /// computed using algorithm 7 from /// security margin here differs from Rescue Prime specification which suggests 50% security diff --git a/examples/src/vdf/exempt/air.rs b/examples/src/vdf/exempt/air.rs index 9254e4e0a..015778459 100644 --- a/examples/src/vdf/exempt/air.rs +++ b/examples/src/vdf/exempt/air.rs @@ -29,7 +29,7 @@ impl ToElements for VdfInputs { // ================================================================================================ pub struct VdfAir { - context: AirContext, + context: AirContext, seed: BaseElement, result: BaseElement, } @@ -37,16 +37,14 @@ pub struct VdfAir { impl Air for VdfAir { type BaseField = BaseElement; type PublicInputs = VdfInputs; - type GkrProof = (); - type GkrVerifier = (); fn new(trace_info: TraceInfo, pub_inputs: VdfInputs, options: ProofOptions) -> Self { let degrees = vec![TransitionConstraintDegree::new(3)]; assert_eq!(TRACE_WIDTH, trace_info.width()); // make sure the last two rows are excluded from transition constraints as we populate // values in the last row with garbage - let context = - AirContext::new(trace_info, degrees, 2, options).set_num_transition_exemptions(2); + let context = AirContext::new(trace_info, pub_inputs.clone(), degrees, 2, options) + .set_num_transition_exemptions(2); Self { context, seed: pub_inputs.seed, @@ -76,7 +74,7 @@ impl Air for VdfAir { ] } - fn context(&self) -> &AirContext { + fn context(&self) -> &AirContext { &self.context } } diff --git a/examples/src/vdf/regular/air.rs b/examples/src/vdf/regular/air.rs index b434c1478..bec2ccb3c 100644 --- a/examples/src/vdf/regular/air.rs +++ b/examples/src/vdf/regular/air.rs @@ -29,7 +29,7 @@ impl ToElements for VdfInputs { // ================================================================================================ pub struct VdfAir { - context: AirContext, + context: AirContext, seed: BaseElement, result: BaseElement, } @@ -37,14 +37,12 @@ pub struct VdfAir { impl Air for VdfAir { type BaseField = BaseElement; type PublicInputs = VdfInputs; - type GkrProof = (); - type GkrVerifier = (); fn new(trace_info: TraceInfo, pub_inputs: VdfInputs, options: ProofOptions) -> Self { let degrees = vec![TransitionConstraintDegree::new(3)]; assert_eq!(TRACE_WIDTH, trace_info.width()); Self { - context: AirContext::new(trace_info, degrees, 2, options), + context: AirContext::new(trace_info, pub_inputs.clone(), degrees, 2, options), seed: pub_inputs.seed, result: pub_inputs.result, } @@ -67,7 +65,7 @@ impl Air for VdfAir { vec![Assertion::single(0, 0, self.seed), Assertion::single(0, last_step, self.result)] } - fn context(&self) -> &AirContext { + fn context(&self) -> &AirContext { &self.context } } diff --git a/math/src/field/f64/mod.rs b/math/src/field/f64/mod.rs index 119676076..64c637c0a 100644 --- a/math/src/field/f64/mod.rs +++ b/math/src/field/f64/mod.rs @@ -3,9 +3,10 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. -//! An implementation of a 64-bit STARK-friendly prime field with modulus $2^{64} - 2^{32} + 1$ -//! using Montgomery representation. -//! Our implementation follows and is constant-time. +//! An implementation of a 64-bit STARK-friendly prime field with modulus $2^{64} - 2^{32} + 1$. +//! +//! Our implementation uses Montgomery representation and follows +//! and is constant-time. //! //! This field supports very fast modular arithmetic and has a number of other attractive //! properties, including: diff --git a/prover/Cargo.toml b/prover/Cargo.toml index 36272766f..6fef7f90f 100644 --- a/prover/Cargo.toml +++ b/prover/Cargo.toml @@ -35,6 +35,8 @@ crypto = { version = "0.9", path = "../crypto", package = "winter-crypto", defau fri = { version = "0.9", path = '../fri', package = "winter-fri", default-features = false } math = { version = "0.9", path = "../math", package = "winter-math", default-features = false } maybe_async = { path = "../utils/maybe_async" , package = "winter-maybe-async" } +sumcheck = { version = "0.1", path = "../sumcheck", package = "winter-sumcheck", default-features = false } +thiserror = { version = "1.0", git = "https://github.com/bitwalker/thiserror", branch = "no-std", default-features = false } tracing = { version = "0.1", default-features = false, features = ["attributes"]} utils = { version = "0.9", path = "../utils/core", package = "winter-utils", default-features = false } diff --git a/prover/benches/lagrange_kernel.rs b/prover/benches/lagrange_kernel.rs index 7ee8ab3c3..348554806 100644 --- a/prover/benches/lagrange_kernel.rs +++ b/prover/benches/lagrange_kernel.rs @@ -7,15 +7,14 @@ use std::time::Duration; use air::{ Air, AirContext, Assertion, AuxRandElements, ConstraintCompositionCoefficients, - EvaluationFrame, FieldExtension, GkrRandElements, LagrangeKernelRandElements, ProofOptions, - TraceInfo, TransitionConstraintDegree, + EvaluationFrame, FieldExtension, ProofOptions, TraceInfo, TransitionConstraintDegree, }; use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion}; -use crypto::{hashers::Blake3_256, DefaultRandomCoin, MerkleTree, RandomCoin}; +use crypto::{hashers::Blake3_256, DefaultRandomCoin, MerkleTree}; use math::{fields::f64::BaseElement, ExtensionOf, FieldElement}; use winter_prover::{ - matrix::ColMatrix, DefaultConstraintEvaluator, DefaultTraceLde, Prover, ProverGkrProof, - StarkDomain, Trace, TracePolyTable, + matrix::ColMatrix, DefaultConstraintEvaluator, DefaultTraceLde, Prover, StarkDomain, Trace, + TracePolyTable, }; const TRACE_LENS: [usize; 2] = [2_usize.pow(16), 2_usize.pow(20)]; @@ -61,7 +60,7 @@ impl LagrangeTrace { Self { main_trace: ColMatrix::new(vec![main_trace_col]), - info: TraceInfo::new_multi_segment(1, aux_segment_width, 0, trace_len, vec![]), + info: TraceInfo::new_multi_segment(1, aux_segment_width, 0, trace_len, vec![], false), } } @@ -94,31 +93,28 @@ impl Trace for LagrangeTrace { // ================================================================================================= struct LagrangeKernelAir { - context: AirContext, + context: AirContext, } impl Air for LagrangeKernelAir { type BaseField = BaseElement; - type GkrProof = (); - type GkrVerifier = (); - type PublicInputs = (); fn new(trace_info: TraceInfo, _pub_inputs: Self::PublicInputs, options: ProofOptions) -> Self { Self { context: AirContext::new_multi_segment( trace_info, + _pub_inputs, vec![TransitionConstraintDegree::new(1)], vec![TransitionConstraintDegree::new(1)], 1, 1, - Some(0), options, ), } } - fn context(&self) -> &AirContext { + fn context(&self) -> &AirContext { &self.context } @@ -221,42 +217,14 @@ impl Prover for LagrangeProver { DefaultConstraintEvaluator::new(air, aux_rand_elements, composition_coefficients) } - fn generate_gkr_proof( - &self, - main_trace: &Self::Trace, - public_coin: &mut Self::RandomCoin, - ) -> (ProverGkrProof, GkrRandElements) - where - E: FieldElement, - { - let main_trace = main_trace.main_segment(); - let lagrange_kernel_rand_elements = { - let log_trace_len = main_trace.num_rows().ilog2() as usize; - let mut rand_elements = Vec::with_capacity(log_trace_len); - for _ in 0..log_trace_len { - rand_elements.push(public_coin.draw().unwrap()); - } - - LagrangeKernelRandElements::new(rand_elements) - }; - - ((), GkrRandElements::new(lagrange_kernel_rand_elements, Vec::new())) - } - - fn build_aux_trace( - &self, - main_trace: &Self::Trace, - aux_rand_elements: &AuxRandElements, - ) -> ColMatrix + fn build_aux_trace(&self, main_trace: &Self::Trace, aux_rand_elements: &[E]) -> ColMatrix where E: FieldElement, { let main_trace = main_trace.main_segment(); let mut columns = Vec::new(); - let lagrange_kernel_rand_elements = aux_rand_elements - .lagrange() - .expect("expected lagrange kernel random elements to be present."); + let lagrange_kernel_rand_elements = aux_rand_elements; // first build the Lagrange kernel column { diff --git a/prover/src/constraints/evaluator/default.rs b/prover/src/constraints/evaluator/default.rs index 8f96c7dcd..ea02b41d4 100644 --- a/prover/src/constraints/evaluator/default.rs +++ b/prover/src/constraints/evaluator/default.rs @@ -158,7 +158,7 @@ where &composition_coefficients.boundary, ); - let lagrange_constraints_evaluator = if air.context().has_lagrange_kernel_aux_column() { + let lagrange_constraints_evaluator = if air.context().logup_gkr_enabled() { let aux_rand_elements = aux_rand_elements.as_ref().expect("expected aux rand elements to be present"); let lagrange_rand_elements = aux_rand_elements @@ -198,7 +198,7 @@ where fragment: &mut EvaluationTableFragment, ) { // initialize buffers to hold trace values and evaluation results at each step; - let mut main_frame = EvaluationFrame::new(trace.trace_info().main_trace_width()); + let mut main_frame = EvaluationFrame::new(trace.trace_info().main_segment_width()); let mut evaluations = vec![E::ZERO; fragment.num_columns()]; let mut t_evaluations = vec![E::BaseField::ZERO; self.num_main_transition_constraints()]; @@ -249,7 +249,7 @@ where fragment: &mut EvaluationTableFragment, ) { // initialize buffers to hold trace values and evaluation results at each step - let mut main_frame = EvaluationFrame::new(trace.trace_info().main_trace_width()); + let mut main_frame = EvaluationFrame::new(trace.trace_info().main_segment_width()); let mut aux_frame = EvaluationFrame::new(trace.trace_info().aux_segment_width()); let mut tm_evaluations = vec![E::BaseField::ZERO; self.num_main_transition_constraints()]; let mut ta_evaluations = vec![E::ZERO; self.num_aux_transition_constraints()]; diff --git a/prover/src/constraints/evaluator/periodic_table.rs b/prover/src/constraints/evaluator/periodic_table.rs index ec72aa766..f1fc751e0 100644 --- a/prover/src/constraints/evaluator/periodic_table.rs +++ b/prover/src/constraints/evaluator/periodic_table.rs @@ -94,7 +94,7 @@ mod tests { use air::Air; use math::{ - fields::f128::BaseElement, get_power_series_with_offset, polynom, FieldElement, StarkField, + fields::f64::BaseElement, get_power_series_with_offset, polynom, FieldElement, StarkField, }; use crate::tests::MockAir; @@ -104,8 +104,8 @@ mod tests { let trace_length = 32; // instantiate AIR with 2 periodic columns - let col1 = vec![1u128, 2].into_iter().map(BaseElement::new).collect::>(); - let col2 = vec![3u128, 4, 5, 6].into_iter().map(BaseElement::new).collect::>(); + let col1 = vec![1u64, 2].into_iter().map(BaseElement::new).collect::>(); + let col2 = vec![3u64, 4, 5, 6].into_iter().map(BaseElement::new).collect::>(); let air = MockAir::with_periodic_columns(vec![col1, col2], trace_length); // build a table of periodic values diff --git a/prover/src/errors.rs b/prover/src/errors.rs index a0d01a233..3a14de46e 100644 --- a/prover/src/errors.rs +++ b/prover/src/errors.rs @@ -21,6 +21,8 @@ pub enum ProverError { /// This error occurs when the base field specified by the AIR does not support field extension /// of degree specified by proof options. UnsupportedFieldExtension(usize), + /// This error occurs when generation of the GKR proof for the LogUp relation fails. + FailedToGenerateGkrProof, } impl fmt::Display for ProverError { @@ -36,6 +38,9 @@ impl fmt::Display for ProverError { Self::UnsupportedFieldExtension(degree) => { write!(f, "field extension of degree {degree} is not supported for the specified base field") } + ProverError::FailedToGenerateGkrProof => { + write!(f, "Failed to generate the GKR proof for the LogUp relation") + } } } } diff --git a/prover/src/lib.rs b/prover/src/lib.rs index ac0e82be2..703f19d8c 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -48,7 +48,7 @@ pub use air::{ EvaluationFrame, FieldExtension, LagrangeKernelRandElements, ProofOptions, TraceInfo, TransitionConstraintDegree, }; -use air::{AuxRandElements, GkrRandElements}; +use air::{AuxRandElements, GkrData, LogUpGkrEvaluator}; pub use crypto; use crypto::{ElementHasher, RandomCoin, VectorCommitment}; use fri::FriProver; @@ -58,6 +58,7 @@ use math::{ fields::{CubeExtension, QuadExtension}, ExtensibleField, FieldElement, StarkField, ToElements, }; +use sumcheck::FinalOpeningClaim; use tracing::{event, info_span, instrument, Level}; pub use utils::{ iterators, ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable, @@ -86,6 +87,9 @@ pub use trace::{ TraceTableFragment, }; +mod logup_gkr; +use logup_gkr::{build_lagrange_column, build_s_column, prove_gkr}; + mod channel; use channel::ProverChannel; @@ -101,9 +105,6 @@ pub mod tests; // this segment width seems to give the best performance for small fields (i.e., 64 bits) const DEFAULT_SEGMENT_WIDTH: usize = 8; -/// Accesses the `GkrProof` type in a [`Prover`]. -pub type ProverGkrProof

= <

::Air as Air>::GkrProof; - /// Defines a STARK prover for a computation. /// /// A STARK prover can be used to generate STARK proofs. The prover contains definitions of a @@ -201,28 +202,10 @@ pub trait Prover { // PROVIDED METHODS // -------------------------------------------------------------------------------------------- - /// Builds the GKR proof. If the [`Air`] doesn't use a GKR proof, leave unimplemented. - #[allow(unused_variables)] - #[maybe_async] - fn generate_gkr_proof( - &self, - main_trace: &Self::Trace, - public_coin: &mut Self::RandomCoin, - ) -> (ProverGkrProof, GkrRandElements) - where - E: FieldElement, - { - unimplemented!("`Prover::generate_gkr_proof` needs to be implemented when the auxiliary trace has a Lagrange kernel column.") - } - /// Builds and returns the auxiliary trace. #[allow(unused_variables)] #[maybe_async] - fn build_aux_trace( - &self, - main_trace: &Self::Trace, - aux_rand_elements: &AuxRandElements, - ) -> ColMatrix + fn build_aux_trace(&self, main_trace: &Self::Trace, aux_rand_elements: &[E]) -> ColMatrix where E: FieldElement, { @@ -241,7 +224,6 @@ pub trait Prover { fn prove(&self, trace: Self::Trace) -> Result where ::PublicInputs: Send, - ::GkrProof: Send, { // figure out which version of the generic proof generation procedure to run. this is a sort // of static dispatch for selecting two generic parameter: extension field and hash @@ -275,7 +257,6 @@ pub trait Prover { where E: FieldElement, ::PublicInputs: Send, - ::GkrProof: Send, { // 0 ----- instantiate AIR and prover channel --------------------------------------------- @@ -314,27 +295,40 @@ pub trait Prover { // build the auxiliary trace segment, and append the resulting segments to trace commitment // and trace polynomial table structs let aux_trace_with_metadata = if air.trace_info().is_multi_segment() { - let (gkr_proof, aux_rand_elements) = if air.context().has_lagrange_kernel_aux_column() { - let (gkr_proof, gkr_rand_elements) = - maybe_await!(self.generate_gkr_proof(&trace, channel.public_coin())); - - let rand_elements = air - .get_aux_rand_elements(channel.public_coin()) - .expect("failed to draw random elements for the auxiliary trace segment"); - - let aux_rand_elements = - AuxRandElements::new_with_gkr(rand_elements, gkr_rand_elements); - - (Some(gkr_proof), aux_rand_elements) + // build the auxiliary segment without the LogUp-GKR related part + let aux_rand_elements = air + .get_aux_rand_elements(channel.public_coin()) + .expect("failed to draw random elements for the auxiliary trace segment"); + let mut aux_trace = maybe_await!(self.build_aux_trace(&trace, &aux_rand_elements)); + + // build the LogUp-GKR related section of the auxiliary segment, if any. This will also + // build an object containing randomness and data related to the LogUp-GKR section of + // the auxiliary trace segment. + let (gkr_proof, gkr_rand_elements) = if air.context().logup_gkr_enabled() { + let gkr_proof = + prove_gkr(&trace, &air.get_logup_gkr_evaluator(), channel.public_coin()) + .map_err(|_| ProverError::FailedToGenerateGkrProof)?; + + let FinalOpeningClaim { eval_point, openings } = + gkr_proof.get_final_opening_claim(); + + let gkr_data = air + .get_logup_gkr_evaluator() + .generate_univariate_iop_for_multi_linear_opening_data( + openings, + eval_point, + channel.public_coin(), + ); + + // add the extra columns required for LogUp-GKR + maybe_await!(build_logup_gkr_columns(&air, &trace, &mut aux_trace, &gkr_data)); + + (Some(gkr_proof), Some(gkr_data)) } else { - let rand_elements = air - .get_aux_rand_elements(channel.public_coin()) - .expect("failed to draw random elements for the auxiliary trace segment"); - - (None, AuxRandElements::new(rand_elements)) + (None, None) }; - - let aux_trace = maybe_await!(self.build_aux_trace(&trace, &aux_rand_elements)); + // build the set of all random values associated to the auxiliary segment + let aux_rand_elements = AuxRandElements::new(aux_rand_elements, gkr_rand_elements); // commit to the auxiliary trace segment let aux_segment_polys = { @@ -616,3 +610,26 @@ pub trait Prover { (constraint_commitment, composition_poly) } } + +/// Builds and appends to the auxiliary segment two additional columns needed for implementing +/// the univariate IOP for multi-linear evaluation of Section 5 in [1]. +/// +/// [1]: https://eprint.iacr.org/2023/1284 +#[maybe_async] +fn build_logup_gkr_columns( + air: &A, + main_trace: &T, + aux_trace: &mut ColMatrix, + gkr_data: &GkrData, +) where + E: FieldElement, + A: Air, + T: Trace, +{ + let evaluator = air.get_logup_gkr_evaluator(); + let lagrange_col = build_lagrange_column(&gkr_data.lagrange_kernel_eval_point); + let s_col = build_s_column(main_trace, gkr_data, &evaluator, &lagrange_col); + + aux_trace.merge_column(s_col); + aux_trace.merge_column(lagrange_col); +} diff --git a/prover/src/logup_gkr/mod.rs b/prover/src/logup_gkr/mod.rs new file mode 100644 index 000000000..643258ee2 --- /dev/null +++ b/prover/src/logup_gkr/mod.rs @@ -0,0 +1,403 @@ +use alloc::vec::Vec; +use core::ops::Add; + +use air::{EvaluationFrame, GkrData, LogUpGkrEvaluator}; +use math::FieldElement; +use sumcheck::{EqFunction, MultiLinearPoly, SumCheckProverError}; +use utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; + +use crate::Trace; + +mod prover; +pub use prover::prove_gkr; + +// EVALUATED CIRCUIT +// ================================================================================================ + +/// Evaluation of a layered circuit for computing a sum of fractions. +/// +/// The circuit computes a sum of fractions based on the formula a / c + b / d = (a * d + b * c) / +/// (c * d) which defines a "gate" ((a, b), (c, d)) --> (a * d + b * c, c * d) upon which the +/// [`EvaluatedCircuit`] is built. Due to the uniformity of the circuit, each of the circuit +/// layers collect all the: +/// +/// 1. `a`'s into a [`MultiLinearPoly`] called `left_numerators`. +/// 2. `b`'s into a [`MultiLinearPoly`] called `right_numerators`. +/// 3. `c`'s into a [`MultiLinearPoly`] called `left_denominators`. +/// 4. `d`'s into a [`MultiLinearPoly`] called `right_denominators`. +/// +/// The relation between two subsequent layers is given by the formula +/// +/// p_0[layer + 1](x_0, x_1, ..., x_{ν - 2}) = p_0[layer](x_0, x_1, ..., x_{ν - 2}, 0) * +/// q_1[layer](x_0, x_1, ..., x_{ν - 2}, 0) +/// + p_1[layer](x_0, x_1, ..., x_{ν - 2}, 0) * q_0[layer](x_0, +/// x_1, ..., x_{ν - 2}, 0) +/// +/// p_1[layer + 1](x_0, x_1, ..., x_{ν - 2}) = p_0[layer](x_0, x_1, ..., x_{ν - 2}, 1) * +/// q_1[layer](x_0, x_1, ..., x_{ν - 2}, 1) +/// + p_1[layer](x_0, x_1, ..., x_{ν - 2}, 1) * q_0[layer](x_0, +/// x_1, ..., x_{ν - 2}, 1) +/// +/// and +/// +/// q_0[layer + 1](x_0, x_1, ..., x_{ν - 2}) = q_0[layer](x_0, x_1, ..., x_{ν - 2}, 0) * +/// q_1[layer](x_0, x_1, ..., x_{ν - 1}, 0) +/// q_1[layer + 1](x_0, x_1, ..., x_{ν - 2}) = q_0[layer](x_0, x_1, ..., x_{ν - 2}, 1) * +/// q_1[layer](x_0, x_1, ..., x_{ν - 1}, 1) +/// +/// This logic is encoded in [`CircuitWire`]. +/// +/// This means that layer ν will be the output layer and will consist of four values +/// (p_0[ν - 1], p_1[ν - 1], p_0[ν - 1], p_1[ν - 1]) ∈ 𝔽^ν. +pub struct EvaluatedCircuit { + layer_polys: Vec>, +} + +impl EvaluatedCircuit { + /// Creates a new [`EvaluatedCircuit`] by evaluating the circuit where the input layer is + /// defined from the main trace columns. + pub fn new( + main_trace_columns: &impl Trace, + evaluator: &impl LogUpGkrEvaluator, + log_up_randomness: &[E], + ) -> Result { + let mut layer_polys = Vec::new(); + + let mut current_layer = + Self::generate_input_layer(main_trace_columns, evaluator, log_up_randomness); + while current_layer.num_wires() > 1 { + let next_layer = Self::compute_next_layer(¤t_layer); + + layer_polys.push(CircuitLayerPolys::from_circuit_layer(current_layer)); + + current_layer = next_layer; + } + + Ok(Self { layer_polys }) + } + + /// Returns all layers of the evaluated circuit, starting from the input layer. + /// + /// Note that the return type is a slice of [`CircuitLayerPolys`] as opposed to + /// [`CircuitLayer`], since the evaluated layers are stored in a representation which can be + /// proved using GKR. + pub fn layers(self) -> Vec> { + self.layer_polys + } + + /// Returns the numerator/denominator polynomials representing the output layer of the circuit. + pub fn output_layer(&self) -> &CircuitLayerPolys { + self.layer_polys.last().expect("circuit has at least one layer") + } + + /// Evaluates the output layer at `query`, where the numerators of the output layer are treated + /// as evaluations of a multilinear polynomial, and similarly for the denominators. + pub fn evaluate_output_layer(&self, query: E) -> (E, E) { + let CircuitLayerPolys { numerators, denominators } = self.output_layer(); + + (numerators.evaluate(&[query]), denominators.evaluate(&[query])) + } + + // HELPERS + // ------------------------------------------------------------------------------------------- + + /// Generates the input layer of the circuit from the main trace columns and some randomness + /// provided by the verifier. + fn generate_input_layer( + main_trace: &impl Trace, + evaluator: &impl LogUpGkrEvaluator, + log_up_randomness: &[E], + ) -> CircuitLayer { + let num_fractions = evaluator.get_num_fractions(); + let mut input_layer_wires = + Vec::with_capacity(main_trace.main_segment().num_rows() * num_fractions); + let mut main_frame = EvaluationFrame::new(main_trace.main_segment().num_cols()); + + let mut query = vec![E::BaseField::ZERO; evaluator.get_oracles().len()]; + let mut numerators = vec![E::ZERO; num_fractions]; + let mut denominators = vec![E::ZERO; num_fractions]; + for i in 0..main_trace.main_segment().num_rows() { + let wires_from_trace_row = { + main_trace.read_main_frame(i, &mut main_frame); + + evaluator.build_query(&main_frame, &[], &mut query); + + evaluator.evaluate_query( + &query, + log_up_randomness, + &mut numerators, + &mut denominators, + ); + let input_gates_values: Vec> = numerators + .iter() + .zip(denominators.iter()) + .map(|(numerator, denominator)| CircuitWire::new(*numerator, *denominator)) + .collect(); + input_gates_values + }; + + input_layer_wires.extend(wires_from_trace_row); + } + + CircuitLayer::new(input_layer_wires) + } + + /// Computes the subsequent layer of the circuit from a given layer. + fn compute_next_layer(prev_layer: &CircuitLayer) -> CircuitLayer { + let next_layer_wires = prev_layer + .wires() + .chunks_exact(2) + .map(|input_wires| { + let left_input_wire = input_wires[0]; + let right_input_wire = input_wires[1]; + + // output wire + left_input_wire + right_input_wire + }) + .collect(); + + CircuitLayer::new(next_layer_wires) + } +} + +// CIRCUIT LAYER POLYS +// =============================================================================================== + +/// Holds a layer of an [`EvaluatedCircuit`] in a representation amenable to proving circuit +/// evaluation using GKR. +#[derive(Clone, Debug)] +pub struct CircuitLayerPolys { + pub numerators: MultiLinearPoly, + pub denominators: MultiLinearPoly, +} + +impl CircuitLayerPolys +where + E: FieldElement, +{ + pub fn from_circuit_layer(layer: CircuitLayer) -> Self { + Self::from_wires(layer.wires) + } + + pub fn from_wires(wires: Vec>) -> Self { + let mut numerators = Vec::new(); + let mut denominators = Vec::new(); + + for wire in wires { + numerators.push(wire.numerator); + denominators.push(wire.denominator); + } + + Self { + numerators: MultiLinearPoly::from_evaluations(numerators), + denominators: MultiLinearPoly::from_evaluations(denominators), + } + } + + fn into_numerators_denominators(self) -> (MultiLinearPoly, MultiLinearPoly) { + (self.numerators, self.denominators) + } +} + +impl Serializable for CircuitLayerPolys +where + E: FieldElement, +{ + fn write_into(&self, target: &mut W) { + let Self { numerators, denominators } = self; + numerators.write_into(target); + denominators.write_into(target); + } +} + +impl Deserializable for CircuitLayerPolys +where + E: FieldElement, +{ + fn read_from(source: &mut R) -> Result { + Ok(Self { + numerators: MultiLinearPoly::read_from(source)?, + denominators: MultiLinearPoly::read_from(source)?, + }) + } +} + +// CIRCUIT LAYER +// =============================================================================================== + +/// Represents a layer in a [`EvaluatedCircuit`]. +/// +/// A layer is made up of a set of `n` wires, where `n` is a power of two. This is the natural +/// circuit representation of a layer, where each consecutive pair of wires are summed to yield a +/// wire in the subsequent layer of an [`EvaluatedCircuit`]. +/// +/// Note that a [`Layer`] needs to be first converted to a [`LayerPolys`] before the evaluation of +/// the layer can be proved using GKR. +pub struct CircuitLayer { + wires: Vec>, +} + +impl CircuitLayer { + /// Creates a new [`Layer`] from a set of projective coordinates. + /// + /// Panics if the number of projective coordinates is not a power of two. + pub fn new(wires: Vec>) -> Self { + assert!(wires.len().is_power_of_two()); + + Self { wires } + } + + /// Returns the wires that make up this circuit layer. + pub fn wires(&self) -> &[CircuitWire] { + &self.wires + } + + /// Returns the number of wires in the layer. + pub fn num_wires(&self) -> usize { + self.wires.len() + } +} + +// CIRCUIT WIRE +// =============================================================================================== + +/// Represents a fraction `numerator / denominator` as a pair `(numerator, denominator)`. This is +/// the type for the gates' inputs in [`prover::EvaluatedCircuit`]. +/// +/// Hence, addition is defined in the natural way fractions are added together: `a/b + c/d = (ad + +/// bc) / bd`. +#[derive(Debug, Clone, Copy)] +pub struct CircuitWire { + numerator: E, + denominator: E, +} + +impl CircuitWire +where + E: FieldElement, +{ + /// Creates new projective coordinates from a numerator and a denominator. + pub fn new(numerator: E, denominator: E) -> Self { + assert_ne!(denominator, E::ZERO); + + Self { numerator, denominator } + } +} + +impl Add for CircuitWire +where + E: FieldElement, +{ + type Output = Self; + + fn add(self, other: Self) -> Self { + let numerator = self.numerator * other.denominator + other.numerator * self.denominator; + let denominator = self.denominator * other.denominator; + + Self::new(numerator, denominator) + } +} + +/// Represents a claim to be proven by a subsequent call to the sum-check protocol. +#[derive(Debug)] +pub struct GkrClaim { + pub evaluation_point: Vec, + pub claimed_evaluation: (E, E), +} + +/// We receive our 4 multilinear polynomials which were evaluated at a random point: +/// `left_numerators` (or `p0`), `right_numerators` (or `p1`), `left_denominators` (or `q0`), and +/// `right_denominators` (or `q1`). We'll call the 4 evaluations at a random point `p0(r)`, `p1(r)`, +/// `q0(r)`, and `q1(r)`, respectively, where `r` is the random point. Note that `r` is a shorthand +/// for a tuple of random values `(r_0, ... r_{l-1})`, where `2^{l + 1}` is the number of wires in +/// the layer. +/// +/// It is important to recall how `p0` and `p1` were constructed (and analogously for `q0` and +/// `q1`). They are the `numerators` layer polynomial (or `p`) evaluations `p(0, r)` and `p(1, r)`, +/// obtained from [`MultiLinearPoly::project_least_significant_variable`]. Hence, `[p0, p1]` form +/// the evaluations of polynomial `p'(x_0) = p(x_0, r)`. Then, the round claim for `numerators`, +/// defined as `p(r_layer, r)`, is simply `p'(r_layer)`. +fn reduce_layer_claim( + left_numerators_opening: E, + right_numerators_opening: E, + left_denominators_opening: E, + right_denominators_opening: E, + r_layer: E, +) -> (E, E) +where + E: FieldElement, +{ + // This is the `numerators` layer polynomial `f(x_0) = numerators(x_0, rx_0, ..., rx_{l-1})`, + // where `rx_0, ..., rx_{l-1}` are the random variables that were sampled during the sumcheck + // round for this layer. + let numerators_univariate = + MultiLinearPoly::from_evaluations(vec![left_numerators_opening, right_numerators_opening]); + + // This is analogous to `numerators_univariate`, but for the `denominators` layer polynomial + let denominators_univariate = MultiLinearPoly::from_evaluations(vec![ + left_denominators_opening, + right_denominators_opening, + ]); + + ( + numerators_univariate.evaluate(&[r_layer]), + denominators_univariate.evaluate(&[r_layer]), + ) +} + +/// Builds the auxiliary trace column for the univariate sum-check argument. +/// +/// Following Section 5.2 in [1] and using the inner product representation of multi-linear queries, +/// we need two univariate oracles, or equivalently two columns in the auxiliary trace, namely: +/// +/// 1. The Lagrange oracle, denoted by $c(X)$ in [1], and refered to throughout the codebase by +/// the Lagrange kernel column. +/// 2. The oracle witnessing the univariate sum-check relation defined by the aforementioned inner +/// product i.e., equation (12) in [1]. This oracle is refered to throughout the codebase as +/// the s-column. +/// +/// The following function's purpose is two build the column in point 2 given the one in point 1. +/// +/// [1]: https://eprint.iacr.org/2023/1284 +pub fn build_s_column( + main_trace: &impl Trace, + gkr_data: &GkrData, + evaluator: &impl LogUpGkrEvaluator, + lagrange_kernel_col: &[E], +) -> Vec { + let c = gkr_data.compute_batched_claim(); + let main_segment = main_trace.main_segment(); + let mean = c / E::from(E::BaseField::from(main_segment.num_rows() as u32)); + + let mut result = Vec::with_capacity(main_segment.num_rows()); + let mut last_value = E::ZERO; + result.push(last_value); + + let mut query = vec![E::BaseField::ZERO; evaluator.get_oracles().len()]; + let mut main_frame = EvaluationFrame::new(main_trace.main_segment().num_cols()); + + for (i, item) in lagrange_kernel_col.iter().enumerate().take(main_segment.num_rows() - 1) { + main_trace.read_main_frame(i, &mut main_frame); + + evaluator.build_query(&main_frame, &[], &mut query); + let cur_value = last_value - mean + gkr_data.compute_batched_query(&query) * *item; + + result.push(cur_value); + last_value = cur_value; + } + + result +} + +/// Builds the Lagrange kernel column at a given point. +pub fn build_lagrange_column(lagrange_randomness: &[E]) -> Vec { + EqFunction::new(lagrange_randomness.into()).evaluations() +} + +#[derive(Debug, thiserror::Error)] +pub enum GkrProverError { + #[error("failed to generate the sum-check proof")] + FailedToProveSumCheck(#[from] SumCheckProverError), + #[error("failed to generate the random challenge")] + FailedToGenerateChallenge, +} diff --git a/prover/src/logup_gkr/prover.rs b/prover/src/logup_gkr/prover.rs new file mode 100644 index 000000000..9fc8fe175 --- /dev/null +++ b/prover/src/logup_gkr/prover.rs @@ -0,0 +1,256 @@ +use alloc::vec::Vec; + +use air::{LogUpGkrEvaluator, LogUpGkrOracle}; +use crypto::{ElementHasher, RandomCoin}; +use math::FieldElement; +use sumcheck::{ + sum_check_prove_higher_degree, sumcheck_prove_plain, BeforeFinalLayerProof, CircuitOutput, + EqFunction, FinalLayerProof, GkrCircuitProof, MultiLinearPoly, SumCheckProof, +}; + +use super::{reduce_layer_claim, CircuitLayerPolys, EvaluatedCircuit, GkrClaim, GkrProverError}; +use crate::{matrix::ColMatrix, Trace}; + +// PROVER +// ================================================================================================ + +/// Evaluates and proves a fractional sum circuit given a set of composition polynomials. +/// +/// For the input layer of the circuit, each individual component of the quadruple +/// [p_0, p_1, q_0, q_1] is of the form: +/// +/// m(z_0, ... , z_{μ - 1}, x_0, ... , x_{ν - 1}) = \sum_{y ∈ {0,1}^μ} EQ(z, y) * g_{[y]}(f_0(x_0, +/// ... , x_{ν - 1}), ... , f_{κ - 1}(x_0, ... , x_{ν +/// - 1})) +/// +/// where: +/// +/// 1. μ is the log_2 of the number of different numerator/denominator expressions divided by two. +/// 2. [y] := \sum_{j = 0}^{μ - 1} y_j * 2^j +/// 3. κ is the number of multi-linears (i.e., main trace columns) involved in the computation of +/// the circuit (i.e., virtual bus). +/// 4. ν is the log_2 of the trace length. +/// +/// The above `m` is usually referred to as the merge of the individual composed multi-linear +/// polynomials g_{[y]}(f_0(x_0, ... , x_{ν - 1}), ... , f_{κ - 1}(x_0, ... , x_{ν - 1})). +/// +/// The composition polynomials `g` are provided as inputs and then used in order to compute the +/// evaluations of each of the four merge polynomials over {0, 1}^{μ + ν}. The resulting evaluations +/// are then used in order to evaluate the circuit. At this point, the GKR protocol is used to prove +/// the correctness of circuit evaluation. It should be noted that the input layer, which +/// corresponds to the last layer treated by the GKR protocol, is handled differently from the other +/// layers. More specifically, the sum-check protocol used for the input layer is composed of two +/// sum-check protocols, the first one works directly with the evaluations of the `m`'s over {0, +/// 1}^{μ + ν} and runs for μ - 1 rounds. After these μ - 1 rounds, and using the resulting [`RoundClaim`], +/// we run the second and final sum-check protocol for ν rounds on the composed multi-linear +/// polynomial given by +/// +/// \sum_{y ∈ {0,1}^μ} EQ(ρ', y) * g_{[y]}(f_0(x_0, ... , x_{ν - 1}), ... , f_{κ - 1}(x_0, ... , +/// x_{ν - 1})) +/// +/// where ρ' is the randomness sampled during the first sum-check protocol. +/// +/// As part of the final sum-check protocol, the openings {f_j(ρ)} are provided as part of a +/// [`FinalOpeningClaim`]. This latter claim will be proven by the STARK prover later on using the +/// auxiliary trace. +pub fn prove_gkr( + main_trace: &impl Trace, + evaluator: &impl LogUpGkrEvaluator, + public_coin: &mut impl RandomCoin, +) -> Result, GkrProverError> { + let num_logup_random_values = evaluator.get_num_rand_values(); + let mut logup_randomness: Vec = Vec::with_capacity(num_logup_random_values); + + for _ in 0..num_logup_random_values { + logup_randomness.push(public_coin.draw().expect("failed to generate randomness")); + } + + // evaluate the GKR fractional sum circuit + let circuit = EvaluatedCircuit::new(main_trace, evaluator, &logup_randomness)?; + + // include the circuit output as part of the final proof + let CircuitLayerPolys { numerators, denominators } = circuit.output_layer().clone(); + + // run the GKR prover for all layers except the input layer + let (before_final_layer_proofs, gkr_claim) = prove_intermediate_layers(circuit, public_coin)?; + + // build the MLEs of the relevant main trace columns + let main_trace_mls = + build_mls_from_main_trace_segment(evaluator.get_oracles(), main_trace.main_segment())?; + + let final_layer_proof = + prove_input_layer(evaluator, logup_randomness, main_trace_mls, gkr_claim, public_coin)?; + + Ok(GkrCircuitProof { + circuit_outputs: CircuitOutput { numerators, denominators }, + before_final_layer_proofs, + final_layer_proof, + }) +} + +/// Proves the final GKR layer which corresponds to the input circuit layer. +fn prove_input_layer< + E: FieldElement, + C: RandomCoin, + H: ElementHasher, +>( + evaluator: &impl LogUpGkrEvaluator, + log_up_randomness: Vec, + multi_linear_ext_polys: Vec>, + claim: GkrClaim, + transcript: &mut C, +) -> Result, GkrProverError> { + // parse the [GkrClaim] resulting from the previous GKR layer + let GkrClaim { evaluation_point, claimed_evaluation } = claim; + + transcript.reseed(H::hash_elements(&[claimed_evaluation.0, claimed_evaluation.1])); + let r_batch = transcript.draw().map_err(|_| GkrProverError::FailedToGenerateChallenge)?; + let claim = claimed_evaluation.0 + claimed_evaluation.1 * r_batch; + + let proof = sum_check_prove_higher_degree( + evaluator, + evaluation_point, + claim, + r_batch, + log_up_randomness, + multi_linear_ext_polys, + transcript, + )?; + + Ok(FinalLayerProof::new(proof)) +} + +/// Builds the multi-linear extension polynomials needed to run the final sum-check of GKR for +/// LogUp-GKR. +fn build_mls_from_main_trace_segment( + oracles: &[LogUpGkrOracle], + main_trace: &ColMatrix<::BaseField>, +) -> Result>, GkrProverError> { + let mut mls = vec![]; + + for oracle in oracles { + match oracle { + LogUpGkrOracle::CurrentRow(index) => { + let col = main_trace.get_column(*index); + let values: Vec = col.iter().map(|value| E::from(*value)).collect(); + let ml = MultiLinearPoly::from_evaluations(values); + mls.push(ml) + }, + LogUpGkrOracle::NextRow(index) => { + let col = main_trace.get_column(*index); + let mut values: Vec = col.iter().map(|value| E::from(*value)).collect(); + if let Some(value) = values.last_mut() { + *value = E::ZERO + } + values.rotate_left(1); + let ml = MultiLinearPoly::from_evaluations(values); + mls.push(ml) + }, + LogUpGkrOracle::PeriodicValue(_) => unimplemented!(), + }; + } + Ok(mls) +} + +/// Proves all GKR layers except for input layer. +fn prove_intermediate_layers< + E: FieldElement, + C: RandomCoin, + H: ElementHasher, +>( + circuit: EvaluatedCircuit, + transcript: &mut C, +) -> Result<(BeforeFinalLayerProof, GkrClaim), GkrProverError> { + // absorb the circuit output layer. This corresponds to sending the four values of the output + // layer to the verifier. The verifier then replies with a challenge `r` in order to evaluate + // `p` and `q` at `r` as multi-linears. + let CircuitLayerPolys { numerators, denominators } = circuit.output_layer(); + let mut evaluations = numerators.evaluations().to_vec(); + evaluations.extend_from_slice(denominators.evaluations()); + transcript.reseed(H::hash_elements(&evaluations)); + + // generate the challenge and reduce [p0, p1, q0, q1] to [pr, qr] + let r = transcript.draw().map_err(|_| GkrProverError::FailedToGenerateChallenge)?; + let mut claimed_evaluation = circuit.evaluate_output_layer(r); + + let mut layer_proofs: Vec> = Vec::new(); + let mut evaluation_point = vec![r]; + + // Loop over all inner layers, from output to input. + // + // In a layered circuit, each layer is defined in terms of its predecessor. The first inner + // layer (starting from the output layer) is the first layer that has a predecessor. Here, we + // loop over all inner layers in order to iteratively reduce a layer in terms of its successor + // layer. Note that we don't include the input layer, since its predecessor layer will be + // reduced in terms of the input layer separately in `prove_final_circuit_layer`. + for inner_layer in circuit.layers().into_iter().skip(1).rev().skip(1) { + // construct the Lagrange kernel evaluated at the previous GKR round randomness + let mut eq_mle = EqFunction::ml_at(evaluation_point.into()); + + let (numerators, denominators) = inner_layer.into_numerators_denominators(); + + // run the sumcheck protocol + let proof = sum_check_prove_num_rounds_degree_3( + claimed_evaluation, + numerators, + denominators, + &mut eq_mle, + transcript, + )?; + + // sample a random challenge to reduce claims + transcript.reseed(H::hash_elements(&proof.openings_claim.openings)); + let r_layer = transcript.draw().map_err(|_| GkrProverError::FailedToGenerateChallenge)?; + + // reduce the claim + claimed_evaluation = { + let left_numerators_opening = proof.openings_claim.openings[0]; + let right_numerators_opening = proof.openings_claim.openings[1]; + let left_denominators_opening = proof.openings_claim.openings[2]; + let right_denominators_opening = proof.openings_claim.openings[3]; + + reduce_layer_claim( + left_numerators_opening, + right_numerators_opening, + left_denominators_opening, + right_denominators_opening, + r_layer, + ) + }; + + // collect the randomness used for the current layer + let mut ext = vec![r_layer]; + ext.extend_from_slice(&proof.openings_claim.eval_point); + evaluation_point = ext; + + layer_proofs.push(proof); + } + + Ok(( + BeforeFinalLayerProof { proof: layer_proofs }, + GkrClaim { evaluation_point, claimed_evaluation }, + )) +} + +/// Runs the sum-check prover used in all but the input layer. +#[allow(clippy::too_many_arguments)] +fn sum_check_prove_num_rounds_degree_3< + E: FieldElement, + C: RandomCoin, + H: ElementHasher, +>( + claim: (E, E), + p: MultiLinearPoly, + q: MultiLinearPoly, + eq: &mut MultiLinearPoly, + transcript: &mut C, +) -> Result, GkrProverError> { + // generate challenge to batch two sumchecks + transcript.reseed(H::hash_elements(&[claim.0, claim.1])); + let r_batch = transcript.draw().map_err(|_| GkrProverError::FailedToGenerateChallenge)?; + let claim = claim.0 + claim.1 * r_batch; + + let proof = sumcheck_prove_plain(claim, r_batch, p, q, eq, transcript)?; + + Ok(proof) +} diff --git a/prover/src/tests/mod.rs b/prover/src/tests/mod.rs index 6b44fa0e9..5132e2025 100644 --- a/prover/src/tests/mod.rs +++ b/prover/src/tests/mod.rs @@ -9,7 +9,7 @@ use air::{ Air, AirContext, Assertion, EvaluationFrame, FieldExtension, ProofOptions, TraceInfo, TransitionConstraintDegree, }; -use math::{fields::f128::BaseElement, FieldElement, StarkField}; +use math::{fields::f64::BaseElement, FieldElement, StarkField}; use crate::TraceTable; @@ -34,7 +34,7 @@ pub fn build_fib_trace(length: usize) -> TraceTable { // ================================================================================================ pub struct MockAir { - context: AirContext, + context: AirContext, assertions: Vec>, periodic_columns: Vec>, } @@ -75,8 +75,6 @@ impl MockAir { impl Air for MockAir { type BaseField = BaseElement; type PublicInputs = (); - type GkrProof = (); - type GkrVerifier = (); fn new(trace_info: TraceInfo, _pub_inputs: (), _options: ProofOptions) -> Self { let context = build_context(trace_info, 8, 1); @@ -87,7 +85,7 @@ impl Air for MockAir { } } - fn context(&self) -> &AirContext { + fn context(&self) -> &AirContext { &self.context } @@ -115,8 +113,8 @@ fn build_context( trace_info: TraceInfo, blowup_factor: usize, num_assertions: usize, -) -> AirContext { +) -> AirContext { let options = ProofOptions::new(32, blowup_factor, 0, FieldExtension::None, 4, 31); let t_degrees = vec![TransitionConstraintDegree::new(2)]; - AirContext::new(trace_info, t_degrees, num_assertions, options) + AirContext::new(trace_info, (), t_degrees, num_assertions, options) } diff --git a/prover/src/trace/mod.rs b/prover/src/trace/mod.rs index 26b383a3b..8d7c999bf 100644 --- a/prover/src/trace/mod.rs +++ b/prover/src/trace/mod.rs @@ -5,6 +5,7 @@ use air::{Air, AuxRandElements, EvaluationFrame, LagrangeKernelBoundaryConstraint, TraceInfo}; use math::{polynom, FieldElement, StarkField}; +use sumcheck::GkrCircuitProof; use super::ColMatrix; @@ -22,7 +23,7 @@ mod tests; /// Defines an [`AuxTraceWithMetadata`] type where the type arguments use their equivalents in an /// [`Air`]. -type AirAuxTraceWithMetadata = AuxTraceWithMetadata::GkrProof>; +type AirAuxTraceWithMetadata = AuxTraceWithMetadata; // AUX TRACE WITH METADATA // ================================================================================================ @@ -30,10 +31,10 @@ type AirAuxTraceWithMetadata = AuxTraceWithMetadata::GkrProo /// Holds the auxiliary trace, the random elements used when generating the auxiliary trace, and /// optionally, a GKR proof. See [`crate::Proof`] for more information about the auxiliary /// proof. -pub struct AuxTraceWithMetadata { +pub struct AuxTraceWithMetadata { pub aux_trace: ColMatrix, pub aux_rand_elements: AuxRandElements, - pub gkr_proof: Option, + pub gkr_proof: Option>, } // TRACE TRAIT @@ -79,7 +80,7 @@ pub trait Trace: Sized { /// Returns the number of columns in the main segment of this trace. fn main_trace_width(&self) -> usize { - self.info().main_trace_width() + self.info().main_segment_width() } /// Returns the number of columns in the auxiliary trace segment. @@ -90,21 +91,18 @@ pub trait Trace: Sized { /// Checks if this trace is valid against the specified AIR, and panics if not. /// /// NOTE: this is a very expensive operation and is intended for use only in debug mode. - fn validate( - &self, - air: &A, - aux_trace_with_metadata: Option<&AirAuxTraceWithMetadata>, - ) where + fn validate(&self, air: &A, aux_trace_with_metadata: Option<&AirAuxTraceWithMetadata>) + where A: Air, E: FieldElement, { // make sure the width align; if they don't something went terribly wrong assert_eq!( self.main_trace_width(), - air.trace_info().main_trace_width(), + air.trace_info().main_segment_width(), "inconsistent trace width: expected {}, but was {}", self.main_trace_width(), - air.trace_info().main_trace_width(), + air.trace_info().main_segment_width(), ); // --- 1. make sure the assertions are valid ---------------------------------------------- diff --git a/prover/src/trace/tests.rs b/prover/src/trace/tests.rs index fc653bbde..b08771a3e 100644 --- a/prover/src/trace/tests.rs +++ b/prover/src/trace/tests.rs @@ -5,7 +5,7 @@ use alloc::vec::Vec; -use math::fields::f128::BaseElement; +use math::fields::f64::BaseElement; use crate::{tests::build_fib_trace, Trace}; diff --git a/prover/src/trace/trace_lde/default/tests.rs b/prover/src/trace/trace_lde/default/tests.rs index c06cc2e60..e1b9b6299 100644 --- a/prover/src/trace/trace_lde/default/tests.rs +++ b/prover/src/trace/trace_lde/default/tests.rs @@ -7,7 +7,7 @@ use alloc::vec::Vec; use crypto::{hashers::Blake3_256, ElementHasher, MerkleTree}; use math::{ - fields::f128::BaseElement, get_power_series, get_power_series_with_offset, polynom, + fields::f64::BaseElement, get_power_series, get_power_series_with_offset, polynom, FieldElement, StarkField, }; diff --git a/prover/src/trace/trace_table.rs b/prover/src/trace/trace_table.rs index dfbd6fe72..0d26c73a0 100644 --- a/prover/src/trace/trace_table.rs +++ b/prover/src/trace/trace_table.rs @@ -166,7 +166,7 @@ impl TraceTable { I: FnOnce(&mut [B]), U: FnMut(usize, &mut [B]), { - let mut state = vec![B::ZERO; self.info.main_trace_width()]; + let mut state = vec![B::ZERO; self.info.main_segment_width()]; init(&mut state); self.update_row(0, &state); @@ -255,7 +255,7 @@ impl TraceTable { /// Returns the number of columns in this execution trace. pub fn width(&self) -> usize { - self.info.main_trace_width() + self.info.main_segment_width() } /// Returns the entire trace column at the specified index. diff --git a/sumcheck/benches/sum_check_high_degree.rs b/sumcheck/benches/sum_check_high_degree.rs index 3db6a37e3..f32329c80 100644 --- a/sumcheck/benches/sum_check_high_degree.rs +++ b/sumcheck/benches/sum_check_high_degree.rs @@ -8,7 +8,7 @@ use std::{marker::PhantomData, time::Duration}; use air::{EvaluationFrame, LogUpGkrEvaluator, LogUpGkrOracle}; use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion}; use crypto::{hashers::Blake3_192, DefaultRandomCoin, RandomCoin}; -use math::{fields::f64::BaseElement, ExtensionOf, FieldElement}; +use math::{fields::f64::BaseElement, ExtensionOf, FieldElement, StarkField}; use rand_utils::{rand_value, rand_vector}; #[cfg(feature = "concurrent")] pub use rayon::prelude::*; @@ -94,22 +94,30 @@ fn setup_sum_check( } #[derive(Clone, Default)] -pub struct PlainLogUpGkrEval { +pub struct PlainLogUpGkrEval { + oracles: Vec>, _field: PhantomData, } -impl LogUpGkrEvaluator for PlainLogUpGkrEval { - type BaseField = BaseElement; - - type PublicInputs = (); - - fn get_oracles(&self) -> Vec> { +impl PlainLogUpGkrEval { + pub fn new() -> Self { let committed_0 = LogUpGkrOracle::CurrentRow(0); let committed_1 = LogUpGkrOracle::CurrentRow(1); let committed_2 = LogUpGkrOracle::CurrentRow(2); let committed_3 = LogUpGkrOracle::CurrentRow(3); let committed_4 = LogUpGkrOracle::CurrentRow(4); - vec![committed_0, committed_1, committed_2, committed_3, committed_4] + let oracles = vec![committed_0, committed_1, committed_2, committed_3, committed_4]; + Self { oracles, _field: PhantomData } + } +} + +impl LogUpGkrEvaluator for PlainLogUpGkrEval { + type BaseField = BaseElement; + + type PublicInputs = (); + + fn get_oracles(&self) -> &[LogUpGkrOracle] { + &self.oracles } fn get_num_rand_values(&self) -> usize { diff --git a/sumcheck/src/lib.rs b/sumcheck/src/lib.rs index f30db974c..5beac9bb9 100644 --- a/sumcheck/src/lib.rs +++ b/sumcheck/src/lib.rs @@ -124,8 +124,12 @@ where /// A proof for the input circuit layer i.e., the final layer in the GKR protocol. #[derive(Debug, Clone)] -pub struct FinalLayerProof { - pub proof: SumCheckProof, +pub struct FinalLayerProof(SumCheckProof); + +impl FinalLayerProof { + pub fn new(proof: SumCheckProof) -> Self { + Self(proof) + } } impl Serializable for FinalLayerProof @@ -133,8 +137,7 @@ where E: FieldElement, { fn write_into(&self, target: &mut W) { - let Self { proof } = self; - proof.write_into(target); + self.0.write_into(target); } } @@ -143,9 +146,7 @@ where E: FieldElement, { fn read_from(source: &mut R) -> Result { - Ok(Self { - proof: Deserializable::read_from(source)?, - }) + Ok(Self(Deserializable::read_from(source)?)) } } @@ -170,7 +171,7 @@ pub struct GkrCircuitProof { impl GkrCircuitProof { pub fn get_final_opening_claim(&self) -> FinalOpeningClaim { - self.final_layer_proof.proof.openings_claim.clone() + self.final_layer_proof.0.openings_claim.clone() } } @@ -181,7 +182,7 @@ where fn write_into(&self, target: &mut W) { self.circuit_outputs.write_into(target); self.before_final_layer_proofs.write_into(target); - self.final_layer_proof.proof.write_into(target); + self.final_layer_proof.0.write_into(target); } } diff --git a/sumcheck/src/prover/high_degree.rs b/sumcheck/src/prover/high_degree.rs index a96adee4c..691195925 100644 --- a/sumcheck/src/prover/high_degree.rs +++ b/sumcheck/src/prover/high_degree.rs @@ -53,7 +53,7 @@ use crate::{ /// 2. ${[w]} := \sum_i w_i \cdot 2^i$ and $w := (w_1, \cdots, w_{\mu})$. /// 3. $h_{j}$ and $g_{j}$ are multi-variate polynomials for $j = 0, \cdots, 2^{\mu} - 1$. /// 4. $n := \nu + \mu$ -/// 5. $\mathbb{B}_{\gamma} := \{0, 1\}^{\gamma}$ for positive integer $\gamma$. +/// 5. $\\B_{\gamma} := \{0, 1\}^{\gamma}$ for positive integer $\gamma$. /// /// The sum above is evaluated using a layered circuit with the equation linking the input layer /// values $p_n$ to the next layer values $p_{n-1}$ given by the following relations @@ -111,14 +111,14 @@ use crate::{ /// /// $$ /// p_{n-1}\left(v_2, \cdots, v_{\mu}, x_1, \cdots, x_{\nu}\right) = -/// \sum_{y\in\mathbb{B}_{\nu}} G(y_{1}, ..., y_{\nu}) +/// \sum_{y\in\\B_{\nu}} G(y_{1}, ..., y_{\nu}) /// $$ /// /// and /// /// $$ /// q_{n-1}\left(v_2, \cdots, v_{\mu}, x_1, \cdots, x_{\nu}\right) = -/// \sum_{y\in\mathbb{B}_{\nu}} H\left(y_1, \cdots, y_{\nu} \right) +/// \sum_{y\in\\B_{\nu}} H\left(y_1, \cdots, y_{\nu} \right) /// $$ /// /// where diff --git a/sumcheck/src/verifier/mod.rs b/sumcheck/src/verifier/mod.rs index d1cfae3a4..887598cc8 100644 --- a/sumcheck/src/verifier/mod.rs +++ b/sumcheck/src/verifier/mod.rs @@ -53,6 +53,8 @@ pub fn verify_sum_check_intermediate_layers< Ok(openings_claim.clone()) } +/// Sum-check verifier for the input layer. +/// /// Verifies the final sum-check proof i.e., the one for the input layer, including the final check, /// and returns a [`FinalOpeningClaim`] to the STARK verifier in order to verify the correctness of /// the openings. @@ -64,8 +66,6 @@ pub fn verify_sum_check_input_layer, ) -> Result, SumCheckVerifierError> { - let FinalLayerProof { proof } = proof; - // generate challenge to batch sum-checks transcript.reseed(H::hash_elements(&[claim.0, claim.1])); let r_batch: E = transcript @@ -77,17 +77,17 @@ pub fn verify_sum_check_input_layer Option<&Vec> { - self.gkr_proof.as_ref() + /// Returns the GKR proof, if any. + pub fn read_gkr_proof(&self) -> Result, VerifierError> { + GkrCircuitProof::read_from_bytes( + self.gkr_proof.as_ref().expect("Expected a GKR proof but there was none"), + ) + .map_err(|err| VerifierError::ProofDeserializationError(err.to_string())) } /// Returns trace states at the specified positions of the LDE domain. This also checks if @@ -313,7 +318,7 @@ where ); // parse main trace segment queries - let main_segment_width = air.trace_info().main_trace_width(); + let main_segment_width = air.trace_info().main_segment_width(); let main_segment_queries = queries.remove(0); let (main_segment_query_proofs, main_segment_states) = main_segment_queries .parse::(air.lde_domain_size(), num_queries, main_segment_width) @@ -331,7 +336,7 @@ where let aux_trace_states = if air.trace_info().is_multi_segment() { let mut aux_trace_states = Vec::new(); let segment_queries = queries.remove(0); - let segment_width = air.trace_info().get_aux_segment_width(); + let segment_width = air.trace_info().aux_segment_width(); let (segment_query_proof, segment_trace_states) = segment_queries .parse::(air.lde_domain_size(), num_queries, segment_width) .map_err(|err| { diff --git a/verifier/src/evaluator.rs b/verifier/src/evaluator.rs index 10910a555..b26f7b926 100644 --- a/verifier/src/evaluator.rs +++ b/verifier/src/evaluator.rs @@ -92,25 +92,28 @@ pub fn evaluate_constraints>( let lagrange_coefficients = composition_coefficients .lagrange .expect("expected Lagrange kernel composition coefficients to be present"); - let lagrange_kernel_aux_rand_elements = { - let aux_rand_elements = - aux_rand_elements.expect("expected aux rand elements to be present"); - - aux_rand_elements - .lagrange() - .expect("expected lagrange rand elements to be present") - }; + let air::GkrData { + lagrange_kernel_eval_point: lagrange_kernel_evaluation_point, + openings_combining_randomness: _, + openings: _, + oracles: _, + } = aux_rand_elements + .expect("expected aux rand elements to be present") + .gkr_data() + .expect("expected LogUp-GKR rand elements to be present"); + + // Lagrange kernel constraints let lagrange_constraints = air .get_lagrange_kernel_constraints( lagrange_coefficients, - lagrange_kernel_aux_rand_elements, + &lagrange_kernel_evaluation_point, ) .expect("expected Lagrange kernel constraints to be present"); result += lagrange_constraints.transition.evaluate_and_combine::( lagrange_kernel_column_frame, - lagrange_kernel_aux_rand_elements, + &lagrange_kernel_evaluation_point, x, ); diff --git a/verifier/src/lib.rs b/verifier/src/lib.rs index 2c75ecd1d..cabbda6c7 100644 --- a/verifier/src/lib.rs +++ b/verifier/src/lib.rs @@ -38,7 +38,7 @@ pub use air::{ ConstraintCompositionCoefficients, ConstraintDivisor, DeepCompositionCoefficients, EvaluationFrame, FieldExtension, ProofOptions, TraceInfo, TransitionConstraintDegree, }; -use air::{AuxRandElements, GkrVerifier}; +use air::{AuxRandElements, LogUpGkrEvaluator}; pub use crypto; use crypto::{ElementHasher, Hasher, RandomCoin, VectorCommitment}; use fri::FriVerifier; @@ -47,6 +47,7 @@ use math::{ fields::{CubeExtension, QuadExtension}, FieldElement, ToElements, }; +use sumcheck::FinalOpeningClaim; pub use utils::{ ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable, SliceReader, }; @@ -60,6 +61,9 @@ use evaluator::evaluate_constraints; mod composer; use composer::DeepComposer; +mod logup_gkr; +use logup_gkr::verify_gkr; + mod errors; pub use errors::VerifierError; @@ -175,35 +179,40 @@ where // process auxiliary trace segments (if any), to build a set of random elements for each segment let aux_trace_rand_elements = if air.trace_info().is_multi_segment() { - if air.context().has_lagrange_kernel_aux_column() { - let gkr_proof = { - let gkr_proof_serialized = channel - .read_gkr_proof() - .expect("Expected an a GKR proof because trace has lagrange kernel column"); - - Deserializable::read_from_bytes(gkr_proof_serialized) - .map_err(|err| VerifierError::ProofDeserializationError(err.to_string()))? - }; - let gkr_rand_elements = air - .get_gkr_proof_verifier::() - .verify::(gkr_proof, &mut public_coin) - .map_err(|err| VerifierError::GkrProofVerificationFailed(err.to_string()))?; - - let rand_elements = air.get_aux_rand_elements(&mut public_coin).expect( - "failed to generate the random elements needed to build the auxiliary trace", - ); + // build the set of random elements related to the auxiliary segment without the LogUp-GKR + // related ones. + let trace_rand_elements = air + .get_aux_rand_elements(&mut public_coin) + .expect("failed to generate the random elements needed to build the auxiliary trace"); + + // if LogUp-GKR is enabled, verify the attached proof and build an object which includes + // randomness and data related to LogUp-GKR + if air.context().logup_gkr_enabled() { + let gkr_proof = channel.read_gkr_proof()?; + let logup_gkr_evaluator = air.get_logup_gkr_evaluator(); + + let FinalOpeningClaim { eval_point, openings } = verify_gkr::( + air.context().public_inputs(), + &gkr_proof, + &logup_gkr_evaluator, + &mut public_coin, + ) + .map_err(|err| VerifierError::GkrProofVerificationFailed(err.to_string()))?; + + let gkr_data = logup_gkr_evaluator + .generate_univariate_iop_for_multi_linear_opening_data( + openings, + eval_point, + &mut public_coin, + ); public_coin.reseed(trace_commitments[AUX_TRACE_IDX]); - Some(AuxRandElements::new_with_gkr(rand_elements, gkr_rand_elements)) + Some(AuxRandElements::new(trace_rand_elements, Some(gkr_data))) } else { - let rand_elements = air.get_aux_rand_elements(&mut public_coin).expect( - "failed to generate the random elements needed to build the auxiliary trace", - ); - public_coin.reseed(trace_commitments[AUX_TRACE_IDX]); - Some(AuxRandElements::new(rand_elements)) + Some(AuxRandElements::new(trace_rand_elements, None)) } } else { None diff --git a/verifier/src/logup_gkr/mod.rs b/verifier/src/logup_gkr/mod.rs new file mode 100644 index 000000000..e317e0ab1 --- /dev/null +++ b/verifier/src/logup_gkr/mod.rs @@ -0,0 +1,115 @@ +use alloc::vec::Vec; + +use air::{Air, LogUpGkrEvaluator}; +use crypto::{ElementHasher, RandomCoin}; +use math::FieldElement; +use sumcheck::{ + verify_sum_check_input_layer, verify_sum_check_intermediate_layers, CircuitOutput, + FinalOpeningClaim, GkrCircuitProof, SumCheckVerifierError, +}; + +/// Verifies the validity of a GKR proof for a LogUp-GKR relation. +pub fn verify_gkr< + A: Air, + E: FieldElement, + C: RandomCoin, + H: ElementHasher, +>( + pub_inputs: &A::PublicInputs, + proof: &GkrCircuitProof, + evaluator: &impl LogUpGkrEvaluator, + transcript: &mut C, +) -> Result, VerifierError> { + let num_logup_random_values = evaluator.get_num_rand_values(); + let mut logup_randomness: Vec = Vec::with_capacity(num_logup_random_values); + + for _ in 0..num_logup_random_values { + logup_randomness.push(transcript.draw().expect("failed to generate randomness")); + } + + let GkrCircuitProof { + circuit_outputs, + before_final_layer_proofs, + final_layer_proof, + } = proof; + + let CircuitOutput { numerators, denominators } = circuit_outputs; + let p0 = numerators.evaluations()[0]; + let p1 = numerators.evaluations()[1]; + let q0 = denominators.evaluations()[0]; + let q1 = denominators.evaluations()[1]; + + // make sure that both denominators are not equal to E::ZERO + if q0 == E::ZERO || q1 == E::ZERO { + return Err(VerifierError::ZeroOutputDenominator); + } + + // check that the output matches the expected `claim` + let claim = evaluator.compute_claim(pub_inputs, &logup_randomness); + if (p0 * q1 + p1 * q0) / (q0 * q1) != claim { + return Err(VerifierError::MismatchingCircuitOutput); + } + + // generate the random challenge to reduce two claims into a single claim + let mut evaluations = numerators.evaluations().to_vec(); + evaluations.extend_from_slice(denominators.evaluations()); + transcript.reseed(H::hash_elements(&evaluations)); + let r = transcript.draw().map_err(|_| VerifierError::FailedToGenerateChallenge)?; + + // reduce the claim + let p_r = p0 + r * (p1 - p0); + let q_r = q0 + r * (q1 - q0); + let mut reduced_claim = (p_r, q_r); + + // verify all GKR layers but for the last one + let num_layers = before_final_layer_proofs.proof.len(); + let mut evaluation_point = vec![r]; + for i in 0..num_layers { + let FinalOpeningClaim { eval_point, openings } = verify_sum_check_intermediate_layers( + &before_final_layer_proofs.proof[i], + &evaluation_point, + reduced_claim, + transcript, + )?; + + // generate the random challenge to reduce two claims into a single claim + transcript.reseed(H::hash_elements(&openings)); + let r_layer = transcript.draw().map_err(|_| VerifierError::FailedToGenerateChallenge)?; + + let p0 = openings[0]; + let p1 = openings[1]; + let q0 = openings[2]; + let q1 = openings[3]; + reduced_claim = (p0 + r_layer * (p1 - p0), q0 + r_layer * (q1 - q0)); + + // collect the randomness used for the current layer + let rand_sumcheck = eval_point; + let mut ext = vec![r_layer]; + ext.extend_from_slice(&rand_sumcheck); + evaluation_point = ext; + } + + // verify the proof of the final GKR layer and pass final opening claim for verification + // to the STARK + verify_sum_check_input_layer( + evaluator, + final_layer_proof, + logup_randomness, + &evaluation_point, + reduced_claim, + transcript, + ) + .map_err(VerifierError::FailedToVerifySumCheck) +} + +#[derive(Debug, thiserror::Error)] +pub enum VerifierError { + #[error("one of the claimed circuit denominators is zero")] + ZeroOutputDenominator, + #[error("the output of the fraction circuit is not equal to the expected value")] + MismatchingCircuitOutput, + #[error("failed to generate the random challenge")] + FailedToGenerateChallenge, + #[error("failed to verify the sum-check proof")] + FailedToVerifySumCheck(#[from] SumCheckVerifierError), +} diff --git a/winterfell/src/lib.rs b/winterfell/src/lib.rs index 86c5e0345..c3da2fb6a 100644 --- a/winterfell/src/lib.rs +++ b/winterfell/src/lib.rs @@ -150,12 +150,13 @@ //! ```no_run //! use winterfell::{ //! math::{fields::f128::BaseElement, FieldElement, ToElements}, -//! Air, AirContext, Assertion, GkrVerifier, EvaluationFrame, +//! Air, AirContext, Assertion, EvaluationFrame, //! ProofOptions, TraceInfo, TransitionConstraintDegree, //! crypto::{hashers::Blake3_256, DefaultRandomCoin, MerkleTree}, //! }; //! //! // Public inputs for our computation will consist of the starting value and the end result. +//! #[derive(Clone)] //! pub struct PublicInputs { //! start: BaseElement, //! result: BaseElement, @@ -172,7 +173,7 @@ //! // the computation's context which we'll build in the constructor. The context is used //! // internally by the Winterfell prover/verifier when interpreting this AIR. //! pub struct WorkAir { -//! context: AirContext, +//! context: AirContext, //! start: BaseElement, //! result: BaseElement, //! } @@ -182,8 +183,6 @@ //! // the public inputs must look like. //! type BaseField = BaseElement; //! type PublicInputs = PublicInputs; -//! type GkrProof = (); -//! type GkrVerifier = (); //! //! // Here, we'll construct a new instance of our computation which is defined by 3 //! // parameters: starting value, number of steps, and the end result. Another way to @@ -206,7 +205,7 @@ //! let num_assertions = 2; //! //! WorkAir { -//! context: AirContext::new(trace_info, degrees, num_assertions, options), +//! context: AirContext::new(trace_info, pub_inputs.clone(), degrees, num_assertions, options), //! start: pub_inputs.start, //! result: pub_inputs.result, //! } @@ -246,7 +245,7 @@ //! //! // This is just boilerplate which is used by the Winterfell prover/verifier to retrieve //! // the context of the computation. -//! fn context(&self) -> &AirContext { +//! fn context(&self) -> &AirContext { //! &self.context //! } //! } @@ -269,6 +268,7 @@ //! # EvaluationFrame, TraceInfo, TransitionConstraintDegree, //! # }; //! # +//! # #[derive(Clone)] //! # pub struct PublicInputs { //! # start: BaseElement, //! # result: BaseElement, @@ -281,7 +281,7 @@ //! # } //! # //! # pub struct WorkAir { -//! # context: AirContext, +//! # context: AirContext, //! # start: BaseElement, //! # result: BaseElement, //! # } @@ -289,14 +289,12 @@ //! # impl Air for WorkAir { //! # type BaseField = BaseElement; //! # type PublicInputs = PublicInputs; -//! # type GkrProof = (); -//! # type GkrVerifier = (); //! # //! # fn new(trace_info: TraceInfo, pub_inputs: PublicInputs, options: ProofOptions) -> Self { //! # assert_eq!(1, trace_info.width()); //! # let degrees = vec![TransitionConstraintDegree::new(3)]; //! # WorkAir { -//! # context: AirContext::new(trace_info, degrees, 2, options), +//! # context: AirContext::new(trace_info, pub_inputs.clone(), degrees, 2, options), //! # start: pub_inputs.start, //! # result: pub_inputs.result, //! # } @@ -321,7 +319,7 @@ //! # ] //! # } //! # -//! # fn context(&self) -> &AirContext { +//! # fn context(&self) -> &AirContext { //! # &self.context //! # } //! # } @@ -418,7 +416,7 @@ //! # trace //! # } //! # -//! # +//! # #[derive(Clone)] //! # pub struct PublicInputs { //! # start: BaseElement, //! # result: BaseElement, @@ -431,7 +429,7 @@ //! # } //! # //! # pub struct WorkAir { -//! # context: AirContext, +//! # context: AirContext, //! # start: BaseElement, //! # result: BaseElement, //! # } @@ -439,14 +437,12 @@ //! # impl Air for WorkAir { //! # type BaseField = BaseElement; //! # type PublicInputs = PublicInputs; -//! # type GkrProof = (); -//! # type GkrVerifier = (); //! # //! # fn new(trace_info: TraceInfo, pub_inputs: PublicInputs, options: ProofOptions) -> Self { //! # assert_eq!(1, trace_info.width()); //! # let degrees = vec![TransitionConstraintDegree::new(3)]; //! # WorkAir { -//! # context: AirContext::new(trace_info, degrees, 2, options), +//! # context: AirContext::new(trace_info, pub_inputs.clone(), degrees, 2, options), //! # start: pub_inputs.start, //! # result: pub_inputs.result, //! # } @@ -471,7 +467,7 @@ //! # ] //! # } //! # -//! # fn context(&self) -> &AirContext { +//! # fn context(&self) -> &AirContext { //! # &self.context //! # } //! # } @@ -594,15 +590,14 @@ #[cfg(test)] extern crate std; -pub use air::{AuxRandElements, GkrVerifier}; +pub use air::{AuxRandElements, LogUpGkrEvaluator}; pub use prover::{ crypto, iterators, math, matrix, Air, AirContext, Assertion, AuxTraceWithMetadata, BoundaryConstraint, BoundaryConstraintGroup, CompositionPolyTrace, ConstraintCompositionCoefficients, ConstraintDivisor, ConstraintEvaluator, DeepCompositionCoefficients, DefaultConstraintEvaluator, DefaultTraceLde, EvaluationFrame, - FieldExtension, Proof, ProofOptions, Prover, ProverError, ProverGkrProof, StarkDomain, Trace, - TraceInfo, TraceLde, TracePolyTable, TraceTable, TraceTableFragment, - TransitionConstraintDegree, + FieldExtension, Proof, ProofOptions, Prover, ProverError, StarkDomain, Trace, TraceInfo, + TraceLde, TracePolyTable, TraceTable, TraceTableFragment, TransitionConstraintDegree, }; pub use verifier::{verify, AcceptableOptions, ByteWriter, VerifierError}; diff --git a/winterfell/src/tests.rs b/winterfell/src/tests.rs index 3757e2010..858f35574 100644 --- a/winterfell/src/tests.rs +++ b/winterfell/src/tests.rs @@ -3,30 +3,33 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. -use std::{vec, vec::Vec}; +use std::{marker::PhantomData, vec, vec::Vec}; -use air::{GkrRandElements, LagrangeKernelRandElements}; +use air::{ + Air, AirContext, Assertion, AuxRandElements, ConstraintCompositionCoefficients, FieldExtension, + LogUpGkrEvaluator, LogUpGkrOracle, ProofOptions, TraceInfo, +}; use crypto::MerkleTree; -use prover::{ - crypto::{hashers::Blake3_256, DefaultRandomCoin, RandomCoin}, +use math::StarkField; + +use super::*; +use crate::{ + crypto::{hashers::Blake3_256, DefaultRandomCoin}, math::{fields::f64::BaseElement, ExtensionOf, FieldElement}, matrix::ColMatrix, + DefaultConstraintEvaluator, DefaultTraceLde, Prover, StarkDomain, TracePolyTable, }; -use super::*; - -const AUX_TRACE_WIDTH: usize = 2; - #[test] -fn test_complex_lagrange_kernel_air() { - let trace = LagrangeComplexTrace::new(2_usize.pow(10), AUX_TRACE_WIDTH); - - let prover = LagrangeComplexProver::new(AUX_TRACE_WIDTH); +fn test_logup_gkr() { + let aux_trace_width = 2; + let trace = LogUpGkrSimple::new(2_usize.pow(7), aux_trace_width); + let prover = LogUpGkrSimpleProver::new(aux_trace_width); let proof = prover.prove(trace).unwrap(); verify::< - LagrangeKernelComplexAir, + LogUpGkrSimpleAir, Blake3_256, DefaultRandomCoin>, MerkleTree>, @@ -34,26 +37,48 @@ fn test_complex_lagrange_kernel_air() { .unwrap() } -// LagrangeComplexTrace +// LogUpGkrSimple // ================================================================================================= #[derive(Clone, Debug)] -struct LagrangeComplexTrace { +struct LogUpGkrSimple { // dummy main trace main_trace: ColMatrix, info: TraceInfo, } -impl LagrangeComplexTrace { +impl LogUpGkrSimple { fn new(trace_len: usize, aux_segment_width: usize) -> Self { assert!(trace_len < u32::MAX.try_into().unwrap()); - let main_trace_col: Vec = + let table: Vec = (0..trace_len).map(|idx| BaseElement::from(idx as u32)).collect(); + let mut multiplicity: Vec = + (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + multiplicity[0] = BaseElement::new(3 * trace_len as u64 - 3 * 4); + multiplicity[1] = BaseElement::new(3 * 4); + + let mut values_0: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + + for i in 0..4 { + values_0[i + 4] = BaseElement::ONE; + } + + let mut values_1: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + + for i in 0..4 { + values_1[i + 4] = BaseElement::ONE; + } + + let mut values_2: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + + for i in 0..4 { + values_2[i + 4] = BaseElement::ONE; + } Self { - main_trace: ColMatrix::new(vec![main_trace_col]), - info: TraceInfo::new_multi_segment(1, aux_segment_width, 0, trace_len, vec![]), + main_trace: ColMatrix::new(vec![table, multiplicity, values_0, values_1, values_2]), + info: TraceInfo::new_multi_segment(5, aux_segment_width, 0, trace_len, vec![], true), } } @@ -62,7 +87,7 @@ impl LagrangeComplexTrace { } } -impl Trace for LagrangeComplexTrace { +impl Trace for LogUpGkrSimple { type BaseField = BaseElement; fn info(&self) -> &TraceInfo { @@ -75,76 +100,37 @@ impl Trace for LagrangeComplexTrace { fn read_main_frame(&self, row_idx: usize, frame: &mut EvaluationFrame) { let next_row_idx = row_idx + 1; - assert_ne!(next_row_idx, self.len()); - self.main_trace.read_row_into(row_idx, frame.current_mut()); - self.main_trace.read_row_into(next_row_idx, frame.next_mut()); + self.main_trace.read_row_into(next_row_idx % self.len(), frame.next_mut()); } } // AIR // ================================================================================================= -#[derive(Debug, Clone, Default)] -struct DummyGkrVerifier; - -impl GkrVerifier for DummyGkrVerifier { - // `GkrProof` is log(trace_len) for this dummy example, so that the verifier knows how many aux - // random variables to generate - type GkrProof = usize; - type Error = VerifierError; - - fn verify( - &self, - gkr_proof: usize, - public_coin: &mut impl RandomCoin, - ) -> Result, Self::Error> - where - E: FieldElement, - Hasher: crypto::ElementHasher, - { - let log_trace_len = gkr_proof; - let lagrange_kernel_rand_elements: LagrangeKernelRandElements = { - let mut rand_elements = Vec::with_capacity(log_trace_len); - for _ in 0..log_trace_len { - rand_elements.push(public_coin.draw().unwrap()); - } - - LagrangeKernelRandElements::new(rand_elements) - }; - - Ok(GkrRandElements::new(lagrange_kernel_rand_elements, Vec::new())) - } +struct LogUpGkrSimpleAir { + context: AirContext, } -struct LagrangeKernelComplexAir { - context: AirContext, -} - -impl Air for LagrangeKernelComplexAir { +impl Air for LogUpGkrSimpleAir { type BaseField = BaseElement; - // `GkrProof` is log(trace_len) for this dummy example, so that the verifier knows how many aux - // random variables to generate - type GkrProof = usize; - type GkrVerifier = DummyGkrVerifier; - type PublicInputs = (); fn new(trace_info: TraceInfo, _pub_inputs: Self::PublicInputs, options: ProofOptions) -> Self { Self { - context: AirContext::new_multi_segment( + context: AirContext::with_logup_gkr( trace_info, + _pub_inputs, vec![TransitionConstraintDegree::new(1)], - vec![TransitionConstraintDegree::new(1)], - 1, + vec![], 1, - Some(1), + 0, options, ), } } - fn context(&self) -> &AirContext { + fn context(&self) -> &AirContext { &self.context } @@ -183,44 +169,122 @@ impl Air for LagrangeKernelComplexAir { &self, _aux_rand_elements: &AuxRandElements, ) -> Vec> { - vec![Assertion::single(0, 0, E::ZERO)] + vec![] } - fn get_gkr_proof_verifier>( + fn get_logup_gkr_evaluator( &self, - ) -> Self::GkrVerifier { - DummyGkrVerifier + ) -> impl LogUpGkrEvaluator + { + PlainLogUpGkrEval::new() } } -// LagrangeComplexProver +#[derive(Clone, Default)] +pub struct PlainLogUpGkrEval { + oracles: Vec>, + _field: PhantomData, +} + +impl PlainLogUpGkrEval { + pub fn new() -> Self { + let committed_0 = LogUpGkrOracle::CurrentRow(0); + let committed_1 = LogUpGkrOracle::CurrentRow(1); + let committed_2 = LogUpGkrOracle::CurrentRow(2); + let committed_3 = LogUpGkrOracle::CurrentRow(3); + let committed_4 = LogUpGkrOracle::CurrentRow(4); + let oracles = vec![committed_0, committed_1, committed_2, committed_3, committed_4]; + Self { oracles, _field: PhantomData } + } +} + +impl LogUpGkrEvaluator for PlainLogUpGkrEval { + type BaseField = BaseElement; + + type PublicInputs = (); + + fn get_oracles(&self) -> &[LogUpGkrOracle] { + &self.oracles + } + + fn get_num_rand_values(&self) -> usize { + 1 + } + + fn get_num_fractions(&self) -> usize { + 4 + } + + fn max_degree(&self) -> usize { + 3 + } + + fn build_query(&self, frame: &EvaluationFrame, _periodic_values: &[E], query: &mut [E]) + where + E: FieldElement, + { + query.iter_mut().zip(frame.current().iter()).for_each(|(q, f)| *q = *f) + } + + fn evaluate_query( + &self, + query: &[F], + rand_values: &[E], + numerator: &mut [E], + denominator: &mut [E], + ) where + F: FieldElement, + E: FieldElement + ExtensionOf, + { + assert_eq!(numerator.len(), 4); + assert_eq!(denominator.len(), 4); + assert_eq!(query.len(), 5); + numerator[0] = E::from(query[1]); + numerator[1] = E::ONE; + numerator[2] = E::ONE; + numerator[3] = E::ONE; + + denominator[0] = rand_values[0] - E::from(query[0]); + denominator[1] = -(rand_values[0] - E::from(query[2])); + denominator[2] = -(rand_values[0] - E::from(query[3])); + denominator[3] = -(rand_values[0] - E::from(query[4])); + } + + fn compute_claim(&self, _inputs: &Self::PublicInputs, _rand_values: &[E]) -> E + where + E: FieldElement, + { + E::ZERO + } +} +// Prover // ================================================================================================ -struct LagrangeComplexProver { +struct LogUpGkrSimpleProver { aux_trace_width: usize, options: ProofOptions, } -impl LagrangeComplexProver { +impl LogUpGkrSimpleProver { fn new(aux_trace_width: usize) -> Self { Self { aux_trace_width, - options: ProofOptions::new(1, 2, 0, FieldExtension::None, 2, 1), + options: ProofOptions::new(1, 8, 0, FieldExtension::None, 2, 1), } } } -impl Prover for LagrangeComplexProver { +impl Prover for LogUpGkrSimpleProver { type BaseField = BaseElement; - type Air = LagrangeKernelComplexAir; - type Trace = LagrangeComplexTrace; + type Air = LogUpGkrSimpleAir; + type Trace = LogUpGkrSimple; type HashFn = Blake3_256; type VC = MerkleTree>; type RandomCoin = DefaultRandomCoin; type TraceLde> = DefaultTraceLde; type ConstraintEvaluator<'a, E: FieldElement> = - DefaultConstraintEvaluator<'a, LagrangeKernelComplexAir, E>; + DefaultConstraintEvaluator<'a, LogUpGkrSimpleAir, E>; fn get_pub_inputs(&self, _trace: &Self::Trace) -> <::Air as Air>::PublicInputs { } @@ -253,46 +317,16 @@ impl Prover for LagrangeComplexProver { DefaultConstraintEvaluator::new(air, aux_rand_elements, composition_coefficients) } - fn generate_gkr_proof( - &self, - main_trace: &Self::Trace, - public_coin: &mut Self::RandomCoin, - ) -> (ProverGkrProof, GkrRandElements) + fn build_aux_trace(&self, main_trace: &Self::Trace, _aux_rand_elements: &[E]) -> ColMatrix where E: FieldElement, { let main_trace = main_trace.main_segment(); - let log_trace_len = main_trace.num_rows().ilog2() as usize; - let lagrange_kernel_rand_elements = { - let mut rand_elements = Vec::with_capacity(log_trace_len); - for _ in 0..log_trace_len { - rand_elements.push(public_coin.draw().unwrap()); - } - - LagrangeKernelRandElements::new(rand_elements) - }; - - (log_trace_len, GkrRandElements::new(lagrange_kernel_rand_elements, Vec::new())) - } - - fn build_aux_trace( - &self, - main_trace: &Self::Trace, - aux_rand_elements: &AuxRandElements, - ) -> ColMatrix - where - E: FieldElement, - { - let main_trace = main_trace.main_segment(); - let lagrange_kernel_rand_elements = aux_rand_elements - .lagrange() - .expect("expected lagrange random elements to be present."); let mut columns = Vec::new(); - // First all other auxiliary columns - let rand_summed = lagrange_kernel_rand_elements.iter().fold(E::ZERO, |acc, &r| acc + r); - for _ in 1..self.aux_trace_width { + let rand_summed = E::from(777_u32); + for _ in 0..self.aux_trace_width { // building a dummy auxiliary column let column = main_trace .get_column(0) @@ -303,27 +337,6 @@ impl Prover for LagrangeComplexProver { columns.push(column); } - // then build the Lagrange kernel column - { - let r = &lagrange_kernel_rand_elements; - - let mut lagrange_col = Vec::with_capacity(main_trace.num_rows()); - - for row_idx in 0..main_trace.num_rows() { - let mut row_value = E::ONE; - for (bit_idx, &r_i) in r.iter().enumerate() { - if row_idx & (1 << bit_idx) == 0 { - row_value *= E::ONE - r_i; - } else { - row_value *= r_i; - } - } - lagrange_col.push(row_value); - } - - columns.push(lagrange_col); - } - ColMatrix::new(columns) } } From ac9561d76bcaad24b65f9c9dfd7b96cfedd3c158 Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Tue, 3 Sep 2024 05:45:34 +0200 Subject: [PATCH 03/19] Add support for s-column as part of the support for LogUp-GKR (#297) --- air/src/air/aux.rs | 10 +- air/src/air/context.rs | 10 +- .../air/{ => logup_gkr}/lagrange/boundary.rs | 2 +- air/src/air/{ => logup_gkr}/lagrange/frame.rs | 0 air/src/air/{ => logup_gkr}/lagrange/mod.rs | 8 +- .../{ => logup_gkr}/lagrange/transition.rs | 0 .../air/{logup_gkr.rs => logup_gkr/mod.rs} | 31 +++- air/src/air/logup_gkr/s_column.rs | 56 ++++++++ air/src/air/mod.rs | 35 +---- air/src/air/tests.rs | 1 - air/src/air/trace_info.rs | 28 +++- prover/src/constraints/evaluator/default.rs | 37 ++--- .../evaluator/{lagrange.rs => logup_gkr.rs} | 133 +++++++++++++----- prover/src/constraints/evaluator/mod.rs | 2 +- prover/src/lib.rs | 2 +- prover/src/trace/mod.rs | 63 +++++++-- sumcheck/src/lib.rs | 2 +- verifier/src/composer.rs | 2 +- verifier/src/evaluator.rs | 45 ++++-- winterfell/src/tests.rs | 4 +- 20 files changed, 340 insertions(+), 131 deletions(-) rename air/src/air/{ => logup_gkr}/lagrange/boundary.rs (97%) rename air/src/air/{ => logup_gkr}/lagrange/frame.rs (100%) rename air/src/air/{ => logup_gkr}/lagrange/mod.rs (95%) rename air/src/air/{ => logup_gkr}/lagrange/transition.rs (100%) rename air/src/air/{logup_gkr.rs => logup_gkr/mod.rs} (87%) create mode 100644 air/src/air/logup_gkr/s_column.rs rename prover/src/constraints/evaluator/{lagrange.rs => logup_gkr.rs} (70%) diff --git a/air/src/air/aux.rs b/air/src/air/aux.rs index 7dc9c4f48..33d7d8539 100644 --- a/air/src/air/aux.rs +++ b/air/src/air/aux.rs @@ -5,9 +5,9 @@ use alloc::vec::Vec; -use math::FieldElement; +use math::{ExtensionOf, FieldElement}; -use super::{lagrange::LagrangeKernelRandElements, LogUpGkrOracle}; +use super::{LagrangeKernelRandElements, LogUpGkrOracle}; /// Holds the randomly generated elements used in defining the auxiliary segment of the trace. /// @@ -130,7 +130,11 @@ impl GkrData { .fold(E::ZERO, |acc, (a, b)| acc + *a * *b) } - pub fn compute_batched_query(&self, query: &[E::BaseField]) -> E { + pub fn compute_batched_query(&self, query: &[F]) -> E + where + F: FieldElement, + E: ExtensionOf, + { E::from(query[0]) + query .iter() diff --git a/air/src/air/context.rs b/air/src/air/context.rs index c36173ca3..8998e64e0 100644 --- a/air/src/air/context.rs +++ b/air/src/air/context.rs @@ -250,13 +250,9 @@ impl AirContext { self.aux_transition_constraint_degrees.len() } - /// Returns the index of the auxiliary column which implements the Lagrange kernel, if any - pub fn lagrange_kernel_aux_column_idx(&self) -> Option { - if self.logup_gkr_enabled() { - Some(self.trace_info().aux_segment_width() - 1) - } else { - None - } + /// Returns the index of the auxiliary column which implements the Lagrange kernel, if any. + pub fn lagrange_kernel_column_idx(&self) -> Option { + self.trace_info.lagrange_kernel_column_idx() } /// Returns true if LogUp-GKR is enabled. diff --git a/air/src/air/lagrange/boundary.rs b/air/src/air/logup_gkr/lagrange/boundary.rs similarity index 97% rename from air/src/air/lagrange/boundary.rs rename to air/src/air/logup_gkr/lagrange/boundary.rs index 5d1954615..f3cc886aa 100644 --- a/air/src/air/lagrange/boundary.rs +++ b/air/src/air/logup_gkr/lagrange/boundary.rs @@ -5,7 +5,7 @@ use math::FieldElement; -use crate::{LagrangeKernelEvaluationFrame, LagrangeKernelRandElements}; +use super::{LagrangeKernelEvaluationFrame, LagrangeKernelRandElements}; #[derive(Debug, Clone, Eq, PartialEq)] pub struct LagrangeKernelBoundaryConstraint diff --git a/air/src/air/lagrange/frame.rs b/air/src/air/logup_gkr/lagrange/frame.rs similarity index 100% rename from air/src/air/lagrange/frame.rs rename to air/src/air/logup_gkr/lagrange/frame.rs diff --git a/air/src/air/lagrange/mod.rs b/air/src/air/logup_gkr/lagrange/mod.rs similarity index 95% rename from air/src/air/lagrange/mod.rs rename to air/src/air/logup_gkr/lagrange/mod.rs index fed5897f3..9d80b4437 100644 --- a/air/src/air/lagrange/mod.rs +++ b/air/src/air/logup_gkr/lagrange/mod.rs @@ -3,17 +3,18 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. -mod boundary; use alloc::vec::Vec; use core::ops::Deref; +use math::FieldElement; + +mod boundary; pub use boundary::LagrangeKernelBoundaryConstraint; mod frame; pub use frame::LagrangeKernelEvaluationFrame; mod transition; -use math::FieldElement; pub use transition::LagrangeKernelTransitionConstraints; use crate::LagrangeConstraintsCompositionCoefficients; @@ -22,7 +23,6 @@ use crate::LagrangeConstraintsCompositionCoefficients; pub struct LagrangeKernelConstraints { pub transition: LagrangeKernelTransitionConstraints, pub boundary: LagrangeKernelBoundaryConstraint, - pub lagrange_kernel_col_idx: usize, } impl LagrangeKernelConstraints { @@ -30,7 +30,6 @@ impl LagrangeKernelConstraints { pub fn new( lagrange_composition_coefficients: LagrangeConstraintsCompositionCoefficients, lagrange_kernel_rand_elements: &LagrangeKernelRandElements, - lagrange_kernel_col_idx: usize, ) -> Self { Self { transition: LagrangeKernelTransitionConstraints::new( @@ -40,7 +39,6 @@ impl LagrangeKernelConstraints { lagrange_composition_coefficients.boundary, lagrange_kernel_rand_elements, ), - lagrange_kernel_col_idx, } } } diff --git a/air/src/air/lagrange/transition.rs b/air/src/air/logup_gkr/lagrange/transition.rs similarity index 100% rename from air/src/air/lagrange/transition.rs rename to air/src/air/logup_gkr/lagrange/transition.rs diff --git a/air/src/air/logup_gkr.rs b/air/src/air/logup_gkr/mod.rs similarity index 87% rename from air/src/air/logup_gkr.rs rename to air/src/air/logup_gkr/mod.rs index 0438064d9..a907fad40 100644 --- a/air/src/air/logup_gkr.rs +++ b/air/src/air/logup_gkr/mod.rs @@ -9,7 +9,15 @@ use core::marker::PhantomData; use crypto::{ElementHasher, RandomCoin}; use math::{ExtensionOf, FieldElement, StarkField, ToElements}; -use super::{EvaluationFrame, GkrData, LagrangeKernelRandElements}; +use super::{EvaluationFrame, GkrData, LagrangeConstraintsCompositionCoefficients}; +mod s_column; +use s_column::SColumnConstraint; + +mod lagrange; +pub use lagrange::{ + LagrangeKernelBoundaryConstraint, LagrangeKernelConstraints, LagrangeKernelEvaluationFrame, + LagrangeKernelRandElements, LagrangeKernelTransitionConstraints, +}; /// A trait containing the necessary information in order to run the LogUp-GKR protocol of [1]. /// @@ -116,6 +124,27 @@ pub trait LogUpGkrEvaluator: Clone + Sync { self.get_oracles().to_vec(), ) } + + /// Returns a new [`LagrangeKernelConstraints`]. + fn get_lagrange_kernel_constraints>( + &self, + lagrange_composition_coefficients: LagrangeConstraintsCompositionCoefficients, + lagrange_kernel_rand_elements: &LagrangeKernelRandElements, + ) -> LagrangeKernelConstraints { + LagrangeKernelConstraints::new( + lagrange_composition_coefficients, + lagrange_kernel_rand_elements, + ) + } + + /// Returns a new [`SColumnConstraints`]. + fn get_s_column_constraints>( + &self, + gkr_data: GkrData, + composition_coefficient: E, + ) -> SColumnConstraint { + SColumnConstraint::new(gkr_data, composition_coefficient) + } } #[derive(Clone, Default)] diff --git a/air/src/air/logup_gkr/s_column.rs b/air/src/air/logup_gkr/s_column.rs new file mode 100644 index 000000000..29848ceeb --- /dev/null +++ b/air/src/air/logup_gkr/s_column.rs @@ -0,0 +1,56 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use math::FieldElement; + +use super::{super::Air, EvaluationFrame, GkrData}; +use crate::LogUpGkrEvaluator; + +/// Represents the transition constraint for the s-column, as well as the random coefficient used +/// to linearly combine the constraint into the constraint composition polynomial. +/// +/// The s-column implements the cohomological sum-check argument of [1] and the constraint in +/// [`SColumnConstraint`] is exactly Eq (4) in Lemma 1 in [1]. +/// +/// +/// [1]: https://eprint.iacr.org/2021/930 +pub struct SColumnConstraint { + gkr_data: GkrData, + composition_coefficient: E, +} + +impl SColumnConstraint { + pub fn new(gkr_data: GkrData, composition_coefficient: E) -> Self { + Self { gkr_data, composition_coefficient } + } + + /// Evaluates the transition constraint over the specificed main trace segment, s-column, + /// and Lagrange kernel evaluation frames. + pub fn evaluate( + &self, + air: &A, + main_trace_frame: &EvaluationFrame, + s_cur: E, + s_nxt: E, + l_cur: E, + x: E, + ) -> E + where + A: Air, + { + let batched_claim = self.gkr_data.compute_batched_claim(); + let mean = batched_claim + .mul_base(E::BaseField::ONE / E::BaseField::from(air.trace_length() as u32)); + + let mut query = vec![E::ZERO; air.get_logup_gkr_evaluator().get_oracles().len()]; + air.get_logup_gkr_evaluator().build_query(main_trace_frame, &[], &mut query); + let batched_claim_at_query = self.gkr_data.compute_batched_query::(&query); + let rhs = s_cur - mean + batched_claim_at_query * l_cur; + let lhs = s_nxt; + + let divisor = x.exp((air.trace_length() as u32).into()) - E::ONE; + self.composition_coefficient * (rhs - lhs) / divisor + } +} diff --git a/air/src/air/mod.rs b/air/src/air/mod.rs index 5dcee0717..bedfa5e35 100644 --- a/air/src/air/mod.rs +++ b/air/src/air/mod.rs @@ -6,7 +6,6 @@ use alloc::{collections::BTreeMap, vec::Vec}; use crypto::{RandomCoin, RandomCoinError}; -use logup_gkr::PhantomLogUpGkrEval; use math::{fft, ExtensibleField, ExtensionOf, FieldElement, StarkField, ToElements}; use crate::ProofOptions; @@ -29,15 +28,14 @@ pub use boundary::{BoundaryConstraint, BoundaryConstraintGroup, BoundaryConstrai mod transition; pub use transition::{EvaluationFrame, TransitionConstraintDegree, TransitionConstraints}; -mod lagrange; -pub use lagrange::{ +mod logup_gkr; +use logup_gkr::PhantomLogUpGkrEval; +pub use logup_gkr::{ LagrangeKernelBoundaryConstraint, LagrangeKernelConstraints, LagrangeKernelEvaluationFrame, - LagrangeKernelRandElements, LagrangeKernelTransitionConstraints, + LagrangeKernelRandElements, LagrangeKernelTransitionConstraints, LogUpGkrEvaluator, + LogUpGkrOracle, }; -mod logup_gkr; -pub use logup_gkr::{LogUpGkrEvaluator, LogUpGkrOracle}; - mod coefficients; pub use coefficients::{ ConstraintCompositionCoefficients, DeepCompositionCoefficients, @@ -331,25 +329,6 @@ pub trait Air: Send + Sync { Ok(rand_elements) } - /// Returns a new [`LagrangeKernelConstraints`] if a Lagrange kernel auxiliary column is present - /// in the trace, or `None` otherwise. - fn get_lagrange_kernel_constraints>( - &self, - lagrange_composition_coefficients: LagrangeConstraintsCompositionCoefficients, - lagrange_kernel_rand_elements: &LagrangeKernelRandElements, - ) -> Option> { - if self.context().logup_gkr_enabled() { - let col_idx = self.context().trace_info().aux_segment_width() - 1; - Some(LagrangeKernelConstraints::new( - lagrange_composition_coefficients, - lagrange_kernel_rand_elements, - col_idx, - )) - } else { - None - } - } - /// Returns values for all periodic columns used in the computation. /// /// These values will be used to compute column values at specific states of the computation @@ -600,7 +579,7 @@ pub trait Air: Send + Sync { None }; - let s_col = if self.context().logup_gkr_enabled() { + let s_col_cc = if self.context().logup_gkr_enabled() { Some(public_coin.draw()?) } else { None @@ -610,7 +589,7 @@ pub trait Air: Send + Sync { trace: t_coefficients, constraints: c_coefficients, lagrange: lagrange_cc, - s_col, + s_col: s_col_cc, }) } } diff --git a/air/src/air/tests.rs b/air/src/air/tests.rs index 5e9871ca5..2400cb883 100644 --- a/air/src/air/tests.rs +++ b/air/src/air/tests.rs @@ -225,7 +225,6 @@ impl MockAir { impl Air for MockAir { type BaseField = BaseElement; type PublicInputs = (); - //type LogUpGkrEvaluator = DummyLogUpGkrEval; fn new(trace_info: TraceInfo, _pub_inputs: (), _options: ProofOptions) -> Self { let num_assertions = trace_info.meta()[0] as usize; diff --git a/air/src/air/trace_info.rs b/air/src/air/trace_info.rs index 44aa0a7ea..8c3545539 100644 --- a/air/src/air/trace_info.rs +++ b/air/src/air/trace_info.rs @@ -39,6 +39,10 @@ impl TraceInfo { pub const MAX_META_LENGTH: usize = 65535; /// Maximum number of random elements in the auxiliary trace segment; currently set to 255. pub const MAX_RAND_SEGMENT_ELEMENTS: usize = 255; + /// The Lagrange kernel, if present, is the last column of the auxiliary trace. + pub const LAGRANGE_KERNEL_OFFSET: usize = 1; + /// The s-column, if present, is the second to last column of the auxiliary trace. + pub const S_COLUMN_OFFSET: usize = 2; // CONSTRUCTORS // -------------------------------------------------------------------------------------------- @@ -112,7 +116,7 @@ impl TraceInfo { // validate trace segment widths assert!(main_segment_width > 0, "main trace segment must consist of at least one column"); - let full_width = main_segment_width + aux_segment_width; + let full_width = main_segment_width + aux_segment_width + 2 * logup_gkr as usize; assert!( full_width <= TraceInfo::MAX_TRACE_WIDTH, "total number of columns in the trace cannot be greater than {}, but was {}", @@ -170,9 +174,9 @@ impl TraceInfo { &self.trace_meta } - /// Returns true if an execution trace contains the auxiliary trace segment. + /// Returns true if an execution trace contains an auxiliary trace segment. pub fn is_multi_segment(&self) -> bool { - self.aux_segment_width > 0 + self.aux_segment_width > 0 || self.logup_gkr } /// Returns the number of columns in the main segment of an execution trace. @@ -210,6 +214,24 @@ impl TraceInfo { self.logup_gkr } + /// Returns the index of the auxiliary column which implements the Lagrange kernel, if any. + pub fn lagrange_kernel_column_idx(&self) -> Option { + if self.logup_gkr_enabled() { + Some(self.aux_segment_width() - TraceInfo::LAGRANGE_KERNEL_OFFSET) + } else { + None + } + } + + /// Returns the index of the auxiliary column which implements the s-column, if any. + pub fn s_column_idx(&self) -> Option { + if self.logup_gkr_enabled() { + Some(self.aux_segment_width() - TraceInfo::S_COLUMN_OFFSET) + } else { + None + } + } + /// Returns the number of random elements needed to build all auxiliary columns, except for the /// Lagrange kernel column. pub fn get_num_aux_segment_rand_elements(&self) -> usize { diff --git a/prover/src/constraints/evaluator/default.rs b/prover/src/constraints/evaluator/default.rs index ea02b41d4..4373494f9 100644 --- a/prover/src/constraints/evaluator/default.rs +++ b/prover/src/constraints/evaluator/default.rs @@ -13,9 +13,9 @@ use utils::iter_mut; use utils::{iterators::*, rayon}; use super::{ - super::EvaluationTableFragment, lagrange::LagrangeKernelConstraintsBatchEvaluator, - BoundaryConstraints, CompositionPolyTrace, ConstraintEvaluationTable, ConstraintEvaluator, - PeriodicValueTable, StarkDomain, TraceLde, + super::EvaluationTableFragment, logup_gkr::LogUpGkrConstraintsEvaluator, BoundaryConstraints, + CompositionPolyTrace, ConstraintEvaluationTable, ConstraintEvaluator, PeriodicValueTable, + StarkDomain, TraceLde, }; // CONSTANTS @@ -40,7 +40,7 @@ pub struct DefaultConstraintEvaluator<'a, A: Air, E: FieldElement, transition_constraints: TransitionConstraints, - lagrange_constraints_evaluator: Option>, + logup_gkr_constraints_evaluator: Option>, aux_rand_elements: Option>, periodic_values: PeriodicValueTable, } @@ -117,10 +117,10 @@ where evaluation_table.validate_transition_degrees(); // combine all constraint evaluations into a single column, including the evaluations of the - // Lagrange kernel constraints (if present) + // LogUp-GKR constraints (if present) let combined_evaluations = { let mut constraints_evaluations = evaluation_table.combine(); - self.evaluate_lagrange_kernel_constraints(trace, domain, &mut constraints_evaluations); + self.evaluate_logup_gkr_constraints(trace, domain, &mut constraints_evaluations); constraints_evaluations }; @@ -158,18 +158,21 @@ where &composition_coefficients.boundary, ); - let lagrange_constraints_evaluator = if air.context().logup_gkr_enabled() { + let logup_gkr_constraints_evaluator = if air.context().logup_gkr_enabled() { let aux_rand_elements = aux_rand_elements.as_ref().expect("expected aux rand elements to be present"); - let lagrange_rand_elements = aux_rand_elements - .lagrange() - .expect("expected lagrange rand elements to be present"); - Some(LagrangeKernelConstraintsBatchEvaluator::new( + + Some(LogUpGkrConstraintsEvaluator::new( air, - lagrange_rand_elements.clone(), + aux_rand_elements + .gkr_data() + .expect("expected LogUp-GKR randomness to be present"), composition_coefficients .lagrange .expect("expected Lagrange kernel composition coefficients to be present"), + composition_coefficients + .s_col + .expect("expected s-column composition coefficient to be present"), )) } else { None @@ -179,7 +182,7 @@ where air, boundary_constraints, transition_constraints, - lagrange_constraints_evaluator, + logup_gkr_constraints_evaluator, aux_rand_elements, periodic_values, } @@ -295,7 +298,7 @@ where } } - /// If present, evaluates the Lagrange kernel constraints over the constraint evaluation domain. + /// If present, evaluates the LogUp-GKR constraints over the constraint evaluation domain. /// The evaluation of each constraint (both boundary and transition) is divided by its divisor, /// multiplied by its composition coefficient, the result of which is added to /// `combined_evaluations_accumulator`. @@ -303,14 +306,14 @@ where /// Specifically, `combined_evaluations_accumulator` is a buffer whose length is the size of the /// constraint evaluation domain, where each index contains combined evaluations of other /// constraints in the system. - fn evaluate_lagrange_kernel_constraints>( + fn evaluate_logup_gkr_constraints>( &self, trace: &T, domain: &StarkDomain, combined_evaluations_accumulator: &mut [E], ) { - if let Some(ref lagrange_constraints_evaluator) = self.lagrange_constraints_evaluator { - lagrange_constraints_evaluator.evaluate_constraints( + if let Some(ref logup_gkr_constraints_evaluator) = self.logup_gkr_constraints_evaluator { + logup_gkr_constraints_evaluator.evaluate_constraints( trace, domain, combined_evaluations_accumulator, diff --git a/prover/src/constraints/evaluator/lagrange.rs b/prover/src/constraints/evaluator/logup_gkr.rs similarity index 70% rename from prover/src/constraints/evaluator/lagrange.rs rename to prover/src/constraints/evaluator/logup_gkr.rs index 89d07f62c..f8fa3ae36 100644 --- a/prover/src/constraints/evaluator/lagrange.rs +++ b/prover/src/constraints/evaluator/logup_gkr.rs @@ -6,46 +6,50 @@ use alloc::vec::Vec; use air::{ - Air, LagrangeConstraintsCompositionCoefficients, LagrangeKernelConstraints, - LagrangeKernelEvaluationFrame, LagrangeKernelRandElements, + Air, EvaluationFrame, GkrData, LagrangeConstraintsCompositionCoefficients, + LagrangeKernelConstraints, LagrangeKernelEvaluationFrame, LogUpGkrEvaluator, }; use math::{batch_inversion, FieldElement}; use crate::{StarkDomain, TraceLde}; -/// Contains a specific strategy for evaluating the Lagrange kernel boundary and transition -/// constraints where the divisors' evaluation is batched. -/// -/// Specifically, [`batch_inversion`] is used to reduce the number of divisions performed. -pub struct LagrangeKernelConstraintsBatchEvaluator { +/// Contains a specific strategy for evaluating the Lagrange kernel and s-column boundary and +/// transition constraints. +pub struct LogUpGkrConstraintsEvaluator<'a, E: FieldElement, A: Air> { + air: &'a A, lagrange_kernel_constraints: LagrangeKernelConstraints, - rand_elements: LagrangeKernelRandElements, + gkr_data: GkrData, + s_col_composition_coefficient: E, } -impl LagrangeKernelConstraintsBatchEvaluator { - /// Constructs a new [`LagrangeConstraintsBatchEvaluator`]. - pub fn new( - air: &A, - lagrange_kernel_rand_elements: LagrangeKernelRandElements, +impl<'a, E, A> LogUpGkrConstraintsEvaluator<'a, E, A> +where + E: FieldElement, + A: Air, +{ + /// Constructs a new [`LogUpGkrConstraintsEvaluator`]. + pub fn new( + air: &'a A, + gkr_data: GkrData, lagrange_composition_coefficients: LagrangeConstraintsCompositionCoefficients, - ) -> Self - where - E: FieldElement, - { + s_col_composition_coefficient: E, + ) -> Self { Self { lagrange_kernel_constraints: air + .get_logup_gkr_evaluator() .get_lagrange_kernel_constraints( lagrange_composition_coefficients, - &lagrange_kernel_rand_elements, - ) - .expect("expected Lagrange kernel constraints to be present"), - rand_elements: lagrange_kernel_rand_elements, + gkr_data.lagrange_kernel_rand_elements(), + ), + air, + gkr_data, + s_col_composition_coefficient, } } /// Evaluates the transition and boundary constraints. Specifically, the constraint evaluations /// are divided by their corresponding divisors, and the resulting terms are linearly combined - /// using the composition coefficients. + /// using the constraint composition coefficients. /// /// Writes the evaluations in `combined_evaluations_acc` at the corresponding (constraint /// evaluation) domain index. @@ -64,28 +68,44 @@ impl LagrangeKernelConstraintsBatchEvaluator { ); let boundary_divisors_inv = self.compute_boundary_divisors_inv(domain); - let mut frame = LagrangeKernelEvaluationFrame::new_empty(); + let mut lagrange_frame = LagrangeKernelEvaluationFrame::new_empty(); + + let evaluator = self.air.get_logup_gkr_evaluator(); + let s_col_constraint_divisor = compute_s_col_divisor::(domain, self.air.trace_length()); + let s_col_idx = trace.trace_info().s_column_idx().expect("S-column should be present"); + let l_col_idx = trace + .trace_info() + .lagrange_kernel_column_idx() + .expect("Lagrange kernel should be present"); + let mut main_frame = EvaluationFrame::new(trace.trace_info().main_segment_width()); + let mut aux_frame = EvaluationFrame::new(trace.trace_info().aux_segment_width()); + + let c = self.gkr_data.compute_batched_claim(); + let mean = c / E::from(E::BaseField::from(trace.trace_info().length() as u32)); + let mut query = vec![E::BaseField::ZERO; evaluator.get_oracles().len()]; for step in 0..domain.ce_domain_size() { // compute Lagrange kernel frame trace.read_lagrange_kernel_frame_into( step << lde_shift, - self.lagrange_kernel_constraints.lagrange_kernel_col_idx, - &mut frame, + l_col_idx, + &mut lagrange_frame, ); // Compute the combined transition and boundary constraints evaluations for this row - let combined_evaluations = { + let lagrange_combined_evaluations = { let mut combined_evaluations = E::ZERO; // combine transition constraints for trans_constraint_idx in 0..self.lagrange_kernel_constraints.transition.num_constraints() { - let numerator = self - .lagrange_kernel_constraints - .transition - .evaluate_ith_numerator(&frame, &self.rand_elements, trans_constraint_idx); + let numerator = + self.lagrange_kernel_constraints.transition.evaluate_ith_numerator( + &lagrange_frame, + &self.gkr_data.lagrange_kernel_eval_point, + trans_constraint_idx, + ); let inv_divisor = trans_constraints_divisors .get_inverse_divisor_eval(trans_constraint_idx, step); @@ -94,8 +114,10 @@ impl LagrangeKernelConstraintsBatchEvaluator { // combine boundary constraints { - let boundary_numerator = - self.lagrange_kernel_constraints.boundary.evaluate_numerator_at(&frame); + let boundary_numerator = self + .lagrange_kernel_constraints + .boundary + .evaluate_numerator_at(&lagrange_frame); combined_evaluations += boundary_numerator * boundary_divisors_inv[step]; } @@ -103,7 +125,33 @@ impl LagrangeKernelConstraintsBatchEvaluator { combined_evaluations }; - combined_evaluations_acc[step] += combined_evaluations; + // compute and combine the transition constraints for the s-column. + // The s-column implements the cohomological sum-check argument of [1] and + // the constraint we enfore is exactly Eq (4) in Lemma 1 in [1]. + // + // [1]: https://eprint.iacr.org/2021/930 + let s_col_combined_evaluation = { + trace.read_main_trace_frame_into(step << lde_shift, &mut main_frame); + trace.read_aux_trace_frame_into(step << lde_shift, &mut aux_frame); + + let l_cur = aux_frame.current()[l_col_idx]; + let s_cur = aux_frame.current()[s_col_idx]; + let s_nxt = aux_frame.next()[s_col_idx]; + + evaluator.build_query(&main_frame, &[], &mut query); + let batched_query = self.gkr_data.compute_batched_query(&query); + + let rhs = s_cur - mean + batched_query * l_cur; + let lhs = s_nxt; + + let divisor_at_step = + s_col_constraint_divisor[step % (domain.trace_to_ce_blowup())]; + + (rhs - lhs) * self.s_col_composition_coefficient.mul_base(divisor_at_step) + }; + + combined_evaluations_acc[step] += + lagrange_combined_evaluations + s_col_combined_evaluation; } } @@ -293,3 +341,22 @@ impl TransitionDivisorEvaluator { - E::BaseField::ONE } } + +/// Computes the evaluations of the s-column divisor. +/// +/// The divisor for the s-column is $X^n - 1$ where $n$ is the trace length. This means that +/// we need only compute `ce_blowup` many values and thus only that many exponentiations. +fn compute_s_col_divisor( + domain: &StarkDomain, + trace_length: usize, +) -> Vec { + let degree = trace_length as u32; + let mut result = Vec::with_capacity(domain.trace_to_ce_blowup()); + + for row in 0..domain.trace_to_ce_blowup() { + let x = domain.get_ce_x_at(row).exp(degree.into()) - E::BaseField::ONE; + + result.push(x); + } + batch_inversion(&result) +} diff --git a/prover/src/constraints/evaluator/mod.rs b/prover/src/constraints/evaluator/mod.rs index da8a166c2..0ff6916f8 100644 --- a/prover/src/constraints/evaluator/mod.rs +++ b/prover/src/constraints/evaluator/mod.rs @@ -14,7 +14,7 @@ pub use default::DefaultConstraintEvaluator; mod boundary; use boundary::BoundaryConstraints; -mod lagrange; +mod logup_gkr; mod periodic_table; use periodic_table::PeriodicValueTable; diff --git a/prover/src/lib.rs b/prover/src/lib.rs index 703f19d8c..b62df14d8 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -346,7 +346,7 @@ pub trait Prover { }; trace_polys - .add_aux_segment(aux_segment_polys, air.context().lagrange_kernel_aux_column_idx()); + .add_aux_segment(aux_segment_polys, air.context().lagrange_kernel_column_idx()); Some(AuxTraceWithMetadata { aux_trace, aux_rand_elements, gkr_proof }) } else { diff --git a/prover/src/trace/mod.rs b/prover/src/trace/mod.rs index 8d7c999bf..5fef475e6 100644 --- a/prover/src/trace/mod.rs +++ b/prover/src/trace/mod.rs @@ -3,7 +3,10 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. -use air::{Air, AuxRandElements, EvaluationFrame, LagrangeKernelBoundaryConstraint, TraceInfo}; +use air::{ + Air, AuxRandElements, EvaluationFrame, LagrangeKernelBoundaryConstraint, LogUpGkrEvaluator, + TraceInfo, +}; use math::{polynom, FieldElement, StarkField}; use sumcheck::GkrCircuitProof; @@ -139,7 +142,7 @@ pub trait Trace: Sized { } // then, check the Lagrange kernel assertion, if any - if let Some(lagrange_kernel_col_idx) = air.context().lagrange_kernel_aux_column_idx() { + if let Some(lagrange_kernel_col_idx) = air.context().lagrange_kernel_column_idx() { let boundary_constraint_assertion_value = LagrangeKernelBoundaryConstraint::assertion_value( aux_rand_elements @@ -222,19 +225,27 @@ pub trait Trace: Sized { x *= g; } - // evaluate transition constraints for Lagrange kernel column (if any) and make sure - // they all evaluate to zeros - if let Some(col_idx) = air.context().lagrange_kernel_aux_column_idx() { + // evaluate transition constraints for Lagrange kernel column and s-column, when LogUp-GKR + // is enabled, and make sure they all evaluate to zeros + if air.context().logup_gkr_enabled() { let aux_trace_with_metadata = aux_trace_with_metadata.expect("expected aux trace to be present"); let aux_trace = &aux_trace_with_metadata.aux_trace; let aux_rand_elements = &aux_trace_with_metadata.aux_rand_elements; - - let c = aux_trace.get_column(col_idx); - let v = self.length().ilog2() as usize; - let r = aux_rand_elements.lagrange().expect("expected Lagrange column to be present"); - - // Loop over every constraint + let l_col_idx = air + .context() + .trace_info() + .lagrange_kernel_column_idx() + .expect("should not be None"); + let s_col_idx = air.context().trace_info().s_column_idx().expect("should not be None"); + + let c = aux_trace.get_column(l_col_idx); + let trace_length = self.length(); + let v = trace_length.ilog2() as usize; + let gkr_data = aux_rand_elements.gkr_data().expect("should not be None"); + let r = gkr_data.lagrange_kernel_rand_elements(); + + // Loop over every Lagrange kernel constraint for constraint_idx in 1..v + 1 { let domain_step = 2_usize.pow((v - constraint_idx + 1) as u32); let domain_half_step = 2_usize.pow((v - constraint_idx) as u32); @@ -254,6 +265,36 @@ pub trait Trace: Sized { ); } } + + // Validate the s-column constraint + let evaluator = air.get_logup_gkr_evaluator(); + let mut aux_frame = EvaluationFrame::new(self.aux_trace_width()); + + let c = gkr_data.compute_batched_claim(); + let mean = c / E::from(E::BaseField::from(trace_length as u32)); + + let mut query = vec![E::BaseField::ZERO; evaluator.get_oracles().len()]; + for step in 0..self.length() { + self.read_main_frame(step, &mut main_frame); + read_aux_frame(aux_trace, step, &mut aux_frame); + + let l_cur = aux_frame.current()[l_col_idx]; + let s_cur = aux_frame.current()[s_col_idx]; + let s_nxt = aux_frame.next()[s_col_idx]; + + evaluator.build_query(&main_frame, &[], &mut query); + let batched_query = gkr_data.compute_batched_query(&query); + + let rhs = s_cur - mean + batched_query * l_cur; + let lhs = s_nxt; + + let evaluation = rhs - lhs; + + assert!( + evaluation == E::ZERO, + "s-column transition constraint did not evaluate to ZERO at step {step}" + ); + } } } } diff --git a/sumcheck/src/lib.rs b/sumcheck/src/lib.rs index 5beac9bb9..b7f670a9d 100644 --- a/sumcheck/src/lib.rs +++ b/sumcheck/src/lib.rs @@ -26,7 +26,7 @@ mod univariate; pub use univariate::{CompressedUnivariatePoly, CompressedUnivariatePolyEvals}; mod multilinear; -pub use multilinear::{EqFunction, MultiLinearPoly}; +pub use multilinear::{inner_product, EqFunction, MultiLinearPoly}; /// Represents an opening claim at an evaluation point against a batch of oracles. /// diff --git a/verifier/src/composer.rs b/verifier/src/composer.rs index 5f10ef79f..ae20f4586 100644 --- a/verifier/src/composer.rs +++ b/verifier/src/composer.rs @@ -43,7 +43,7 @@ impl DeepComposer { x_coordinates, z: [z, z * E::from(g_trace)], g_trace, - lagrange_kernel_column_idx: air.context().lagrange_kernel_aux_column_idx(), + lagrange_kernel_column_idx: air.context().lagrange_kernel_column_idx(), } } diff --git a/verifier/src/evaluator.rs b/verifier/src/evaluator.rs index b26f7b926..a226ec9c8 100644 --- a/verifier/src/evaluator.rs +++ b/verifier/src/evaluator.rs @@ -7,7 +7,7 @@ use alloc::vec::Vec; use air::{ Air, AuxRandElements, ConstraintCompositionCoefficients, EvaluationFrame, - LagrangeKernelEvaluationFrame, + LagrangeKernelEvaluationFrame, LogUpGkrEvaluator, }; use math::{polynom, FieldElement}; @@ -89,35 +89,50 @@ pub fn evaluate_constraints>( // 3 ----- evaluate Lagrange kernel constraints ------------------------------------ if let Some(lagrange_kernel_column_frame) = lagrange_kernel_frame { + let logup_gkr_evaluator = air.get_logup_gkr_evaluator(); + let lagrange_coefficients = composition_coefficients .lagrange .expect("expected Lagrange kernel composition coefficients to be present"); - let air::GkrData { - lagrange_kernel_eval_point: lagrange_kernel_evaluation_point, - openings_combining_randomness: _, - openings: _, - oracles: _, - } = aux_rand_elements + + let gkr_data = aux_rand_elements .expect("expected aux rand elements to be present") .gkr_data() .expect("expected LogUp-GKR rand elements to be present"); // Lagrange kernel constraints - let lagrange_constraints = air - .get_lagrange_kernel_constraints( - lagrange_coefficients, - &lagrange_kernel_evaluation_point, - ) - .expect("expected Lagrange kernel constraints to be present"); + let lagrange_constraints = logup_gkr_evaluator.get_lagrange_kernel_constraints( + lagrange_coefficients, + &gkr_data.lagrange_kernel_eval_point, + ); result += lagrange_constraints.transition.evaluate_and_combine::( lagrange_kernel_column_frame, - &lagrange_kernel_evaluation_point, + &gkr_data.lagrange_kernel_eval_point, x, ); - result += lagrange_constraints.boundary.evaluate_at(x, lagrange_kernel_column_frame); + + // s-column constraints + + let s_col_idx = air.trace_info().s_column_idx().expect("s-column should be present"); + + let aux_trace_frame = + aux_trace_frame.as_ref().expect("expected aux rand elements to be present"); + + let s_cur = aux_trace_frame.current()[s_col_idx]; + let s_nxt = aux_trace_frame.next()[s_col_idx]; + let l_cur = lagrange_kernel_column_frame.inner()[0]; + + let s_column_cc = composition_coefficients + .s_col + .expect("expected constraint composition coefficient for s-column to be present"); + + let s_column_constraint = + logup_gkr_evaluator.get_s_column_constraints(gkr_data, s_column_cc); + + result += s_column_constraint.evaluate(air, main_trace_frame, s_cur, s_nxt, l_cur, x); } result diff --git a/winterfell/src/tests.rs b/winterfell/src/tests.rs index 858f35574..99e52971f 100644 --- a/winterfell/src/tests.rs +++ b/winterfell/src/tests.rs @@ -22,7 +22,7 @@ use crate::{ #[test] fn test_logup_gkr() { - let aux_trace_width = 2; + let aux_trace_width = 1; let trace = LogUpGkrSimple::new(2_usize.pow(7), aux_trace_width); let prover = LogUpGkrSimpleProver::new(aux_trace_width); @@ -269,7 +269,7 @@ impl LogUpGkrSimpleProver { fn new(aux_trace_width: usize) -> Self { Self { aux_trace_width, - options: ProofOptions::new(1, 8, 0, FieldExtension::None, 2, 1), + options: ProofOptions::new(1, 8, 0, FieldExtension::Quadratic, 2, 1), } } } From b5f64cc281b8b95ff76c86c97412b783b2573877 Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Tue, 10 Sep 2024 18:10:15 +0200 Subject: [PATCH 04/19] Add support for periodic columns in LogUp-GKR (#307) --- air/src/air/aux.rs | 6 +- air/src/air/logup_gkr/mod.rs | 90 ++++- air/src/air/logup_gkr/s_column.rs | 2 +- air/src/air/mod.rs | 2 +- air/src/lib.rs | 4 +- prover/src/constraints/evaluator/logup_gkr.rs | 2 +- prover/src/logup_gkr/mod.rs | 10 +- prover/src/logup_gkr/prover.rs | 20 +- prover/src/trace/mod.rs | 2 +- sumcheck/benches/sum_check_high_degree.rs | 22 +- sumcheck/src/prover/high_degree.rs | 109 +++++- sumcheck/src/verifier/mod.rs | 30 +- winterfell/src/tests/logup_gkr_periodic.rs | 357 ++++++++++++++++++ .../{tests.rs => tests/logup_gkr_simple.rs} | 9 +- winterfell/src/tests/mod.rs | 8 + 15 files changed, 619 insertions(+), 54 deletions(-) create mode 100644 winterfell/src/tests/logup_gkr_periodic.rs rename winterfell/src/{tests.rs => tests/logup_gkr_simple.rs} (97%) create mode 100644 winterfell/src/tests/mod.rs diff --git a/air/src/air/aux.rs b/air/src/air/aux.rs index 33d7d8539..d9fa3c2d5 100644 --- a/air/src/air/aux.rs +++ b/air/src/air/aux.rs @@ -80,7 +80,7 @@ pub struct GkrData { pub lagrange_kernel_eval_point: LagrangeKernelRandElements, pub openings_combining_randomness: Vec, pub openings: Vec, - pub oracles: Vec>, + pub oracles: Vec, } impl GkrData { @@ -92,7 +92,7 @@ impl GkrData { lagrange_kernel_eval_point: LagrangeKernelRandElements, openings_combining_randomness: Vec, openings: Vec, - oracles: Vec>, + oracles: Vec, ) -> Self { Self { lagrange_kernel_eval_point, @@ -116,7 +116,7 @@ impl GkrData { &self.openings } - pub fn oracles(&self) -> &[LogUpGkrOracle] { + pub fn oracles(&self) -> &[LogUpGkrOracle] { &self.oracles } diff --git a/air/src/air/logup_gkr/mod.rs b/air/src/air/logup_gkr/mod.rs index a907fad40..d3e198912 100644 --- a/air/src/air/logup_gkr/mod.rs +++ b/air/src/air/logup_gkr/mod.rs @@ -35,7 +35,13 @@ pub trait LogUpGkrEvaluator: Clone + Sync { /// Gets a list of all oracles involved in LogUp-GKR; this is intended to be used in construction of /// MLEs. - fn get_oracles(&self) -> &[LogUpGkrOracle]; + fn get_oracles(&self) -> &[LogUpGkrOracle]; + + /// A vector of virtual periodic columns defined by their values in some given cycle. + /// Note that the cycle lengths must be powers of 2. + fn get_periodic_column_values(&self) -> Vec> { + vec![] + } /// Returns the number of random values needed to evaluate a query. fn get_num_rand_values(&self) -> usize; @@ -56,7 +62,7 @@ pub trait LogUpGkrEvaluator: Clone + Sync { /// information returned from `get_oracles()`. However, this implementation is likely to be /// expensive compared to the hand-written implementation. However, we could provide a test /// which verifies that `get_oracles()` and `build_query()` methods are consistent. - fn build_query(&self, frame: &EvaluationFrame, periodic_values: &[E], query: &mut [E]) + fn build_query(&self, frame: &EvaluationFrame, query: &mut [E]) where E: FieldElement; @@ -70,6 +76,7 @@ pub trait LogUpGkrEvaluator: Clone + Sync { fn evaluate_query( &self, query: &[F], + periodic_values: &[F], logup_randomness: &[E], numerators: &mut [E], denominators: &mut [E], @@ -145,6 +152,22 @@ pub trait LogUpGkrEvaluator: Clone + Sync { ) -> SColumnConstraint { SColumnConstraint::new(gkr_data, composition_coefficient) } + + /// Returns the periodic values used in the LogUp-GKR statement, either as base field element + /// during circuit evaluation or as extension field element during the run of sum-check for + /// the input layer. + fn build_periodic_values(&self) -> PeriodicTable + where + E: FieldElement, + { + let table = self + .get_periodic_column_values() + .iter() + .map(|values| values.iter().map(|x| E::from(*x)).collect()) + .collect(); + + PeriodicTable { table } + } } #[derive(Clone, Default)] @@ -175,7 +198,7 @@ where type PublicInputs = P; - fn get_oracles(&self) -> &[LogUpGkrOracle] { + fn get_oracles(&self) -> &[LogUpGkrOracle] { panic!("LogUpGkrEvaluator method called but LogUp-GKR is not implemented") } @@ -191,7 +214,7 @@ where panic!("LogUpGkrEvaluator method called but LogUp-GKR is not implemented") } - fn build_query(&self, _frame: &EvaluationFrame, _periodic_values: &[E], _query: &mut [E]) + fn build_query(&self, _frame: &EvaluationFrame, _query: &mut [E]) where E: FieldElement, { @@ -201,6 +224,7 @@ where fn evaluate_query( &self, _query: &[F], + _periodic_values: &[F], _rand_values: &[E], _numerator: &mut [E], _denominator: &mut [E], @@ -220,12 +244,62 @@ where } #[derive(Clone, Debug, PartialEq, PartialOrd, Eq, Ord)] -pub enum LogUpGkrOracle { +pub enum LogUpGkrOracle { /// A column with a given index in the main trace segment. CurrentRow(usize), /// A column with a given index in the main trace segment but shifted upwards. NextRow(usize), - /// A virtual periodic column defined by its values in a given cycle. Note that the cycle length - /// must be a power of 2. - PeriodicValue(Vec), +} + +// PERIODIC COLUMNS FOR LOGUP +// ================================================================================================= + +/// Stores the periodic columns used in a LogUp-GKR statement. +/// +/// Each stored periodic column is interpreted as a multi-linear extension polynomial of the column +/// with the given periodic values. Due to the periodic nature of the values, storing, binding of +/// an argument and evaluating the said multi-linear extension can be all done linearly in the size +/// of the smallest cycle defining the periodic values. Hence we only store the values of this +/// smallest cycle. The cycle is assumed throughout to be a power of 2. +#[derive(Clone, Debug, Default, PartialEq, PartialOrd, Eq, Ord)] +pub struct PeriodicTable { + pub table: Vec>, +} + +impl PeriodicTable +where + E: FieldElement, +{ + pub fn new(table: Vec>) -> Self { + let table = table.iter().map(|col| col.iter().map(|x| E::from(*x)).collect()).collect(); + + Self { table } + } + + pub fn num_columns(&self) -> usize { + self.table.len() + } + + pub fn table(&self) -> &[Vec] { + &self.table + } + + pub fn fill_periodic_values_at(&self, row: usize, values: &mut [E]) { + self.table + .iter() + .zip(values.iter_mut()) + .for_each(|(col, value)| *value = col[row % col.len()]) + } + + pub fn bind_least_significant_variable(&mut self, round_challenge: E) { + for col in self.table.iter_mut() { + if col.len() > 1 { + let num_evals = col.len() >> 1; + for i in 0..num_evals { + col[i] = col[i << 1] + round_challenge * (col[(i << 1) + 1] - col[i << 1]); + } + col.truncate(num_evals) + } + } + } } diff --git a/air/src/air/logup_gkr/s_column.rs b/air/src/air/logup_gkr/s_column.rs index 29848ceeb..685c6e026 100644 --- a/air/src/air/logup_gkr/s_column.rs +++ b/air/src/air/logup_gkr/s_column.rs @@ -45,7 +45,7 @@ impl SColumnConstraint { .mul_base(E::BaseField::ONE / E::BaseField::from(air.trace_length() as u32)); let mut query = vec![E::ZERO; air.get_logup_gkr_evaluator().get_oracles().len()]; - air.get_logup_gkr_evaluator().build_query(main_trace_frame, &[], &mut query); + air.get_logup_gkr_evaluator().build_query(main_trace_frame, &mut query); let batched_claim_at_query = self.gkr_data.compute_batched_query::(&query); let rhs = s_cur - mean + batched_claim_at_query * l_cur; let lhs = s_nxt; diff --git a/air/src/air/mod.rs b/air/src/air/mod.rs index bedfa5e35..cc2e82d2b 100644 --- a/air/src/air/mod.rs +++ b/air/src/air/mod.rs @@ -33,7 +33,7 @@ use logup_gkr::PhantomLogUpGkrEval; pub use logup_gkr::{ LagrangeKernelBoundaryConstraint, LagrangeKernelConstraints, LagrangeKernelEvaluationFrame, LagrangeKernelRandElements, LagrangeKernelTransitionConstraints, LogUpGkrEvaluator, - LogUpGkrOracle, + LogUpGkrOracle, PeriodicTable, }; mod coefficients; diff --git a/air/src/lib.rs b/air/src/lib.rs index 2993306b9..39ef44d18 100644 --- a/air/src/lib.rs +++ b/air/src/lib.rs @@ -47,6 +47,6 @@ pub use air::{ DeepCompositionCoefficients, EvaluationFrame, GkrData, LagrangeConstraintsCompositionCoefficients, LagrangeKernelBoundaryConstraint, LagrangeKernelConstraints, LagrangeKernelEvaluationFrame, LagrangeKernelRandElements, - LagrangeKernelTransitionConstraints, LogUpGkrEvaluator, LogUpGkrOracle, TraceInfo, - TransitionConstraintDegree, TransitionConstraints, + LagrangeKernelTransitionConstraints, LogUpGkrEvaluator, LogUpGkrOracle, PeriodicTable, + TraceInfo, TransitionConstraintDegree, TransitionConstraints, }; diff --git a/prover/src/constraints/evaluator/logup_gkr.rs b/prover/src/constraints/evaluator/logup_gkr.rs index f8fa3ae36..cc7390b73 100644 --- a/prover/src/constraints/evaluator/logup_gkr.rs +++ b/prover/src/constraints/evaluator/logup_gkr.rs @@ -138,7 +138,7 @@ where let s_cur = aux_frame.current()[s_col_idx]; let s_nxt = aux_frame.next()[s_col_idx]; - evaluator.build_query(&main_frame, &[], &mut query); + evaluator.build_query(&main_frame, &mut query); let batched_query = self.gkr_data.compute_batched_query(&query); let rhs = s_cur - mean + batched_query * l_cur; diff --git a/prover/src/logup_gkr/mod.rs b/prover/src/logup_gkr/mod.rs index 643258ee2..2c4846369 100644 --- a/prover/src/logup_gkr/mod.rs +++ b/prover/src/logup_gkr/mod.rs @@ -109,21 +109,25 @@ impl EvaluatedCircuit { log_up_randomness: &[E], ) -> CircuitLayer { let num_fractions = evaluator.get_num_fractions(); + let periodic_values = evaluator.build_periodic_values(); + let mut input_layer_wires = Vec::with_capacity(main_trace.main_segment().num_rows() * num_fractions); let mut main_frame = EvaluationFrame::new(main_trace.main_segment().num_cols()); let mut query = vec![E::BaseField::ZERO; evaluator.get_oracles().len()]; + let mut periodic_values_row = vec![E::BaseField::ZERO; periodic_values.num_columns()]; let mut numerators = vec![E::ZERO; num_fractions]; let mut denominators = vec![E::ZERO; num_fractions]; for i in 0..main_trace.main_segment().num_rows() { let wires_from_trace_row = { main_trace.read_main_frame(i, &mut main_frame); - - evaluator.build_query(&main_frame, &[], &mut query); + periodic_values.fill_periodic_values_at(i, &mut periodic_values_row); + evaluator.build_query(&main_frame, &mut query); evaluator.evaluate_query( &query, + &periodic_values_row, log_up_randomness, &mut numerators, &mut denominators, @@ -379,7 +383,7 @@ pub fn build_s_column( for (i, item) in lagrange_kernel_col.iter().enumerate().take(main_segment.num_rows() - 1) { main_trace.read_main_frame(i, &mut main_frame); - evaluator.build_query(&main_frame, &[], &mut query); + evaluator.build_query(&main_frame, &mut query); let cur_value = last_value - mean + gkr_data.compute_batched_query(&query) * *item; result.push(cur_value); diff --git a/prover/src/logup_gkr/prover.rs b/prover/src/logup_gkr/prover.rs index 9fc8fe175..f1a66cf35 100644 --- a/prover/src/logup_gkr/prover.rs +++ b/prover/src/logup_gkr/prover.rs @@ -1,6 +1,6 @@ use alloc::vec::Vec; -use air::{LogUpGkrEvaluator, LogUpGkrOracle}; +use air::{LogUpGkrEvaluator, LogUpGkrOracle, PeriodicTable}; use crypto::{ElementHasher, RandomCoin}; use math::FieldElement; use sumcheck::{ @@ -77,9 +77,18 @@ pub fn prove_gkr( // build the MLEs of the relevant main trace columns let main_trace_mls = build_mls_from_main_trace_segment(evaluator.get_oracles(), main_trace.main_segment())?; + // build the periodic table representing periodic columns as multi-linear extensions + let periodic_table = evaluator.build_periodic_values(); - let final_layer_proof = - prove_input_layer(evaluator, logup_randomness, main_trace_mls, gkr_claim, public_coin)?; + // run the GKR prover for the input layer + let final_layer_proof = prove_input_layer( + evaluator, + logup_randomness, + main_trace_mls, + periodic_table, + gkr_claim, + public_coin, + )?; Ok(GkrCircuitProof { circuit_outputs: CircuitOutput { numerators, denominators }, @@ -97,6 +106,7 @@ fn prove_input_layer< evaluator: &impl LogUpGkrEvaluator, log_up_randomness: Vec, multi_linear_ext_polys: Vec>, + periodic_table: PeriodicTable, claim: GkrClaim, transcript: &mut C, ) -> Result, GkrProverError> { @@ -114,6 +124,7 @@ fn prove_input_layer< r_batch, log_up_randomness, multi_linear_ext_polys, + periodic_table, transcript, )?; @@ -123,7 +134,7 @@ fn prove_input_layer< /// Builds the multi-linear extension polynomials needed to run the final sum-check of GKR for /// LogUp-GKR. fn build_mls_from_main_trace_segment( - oracles: &[LogUpGkrOracle], + oracles: &[LogUpGkrOracle], main_trace: &ColMatrix<::BaseField>, ) -> Result>, GkrProverError> { let mut mls = vec![]; @@ -146,7 +157,6 @@ fn build_mls_from_main_trace_segment( let ml = MultiLinearPoly::from_evaluations(values); mls.push(ml) }, - LogUpGkrOracle::PeriodicValue(_) => unimplemented!(), }; } Ok(mls) diff --git a/prover/src/trace/mod.rs b/prover/src/trace/mod.rs index 5fef475e6..2b2e89a9e 100644 --- a/prover/src/trace/mod.rs +++ b/prover/src/trace/mod.rs @@ -282,7 +282,7 @@ pub trait Trace: Sized { let s_cur = aux_frame.current()[s_col_idx]; let s_nxt = aux_frame.next()[s_col_idx]; - evaluator.build_query(&main_frame, &[], &mut query); + evaluator.build_query(&main_frame, &mut query); let batched_query = gkr_data.compute_batched_query(&query); let rhs = s_cur - mean + batched_query * l_cur; diff --git a/sumcheck/benches/sum_check_high_degree.rs b/sumcheck/benches/sum_check_high_degree.rs index f32329c80..483890579 100644 --- a/sumcheck/benches/sum_check_high_degree.rs +++ b/sumcheck/benches/sum_check_high_degree.rs @@ -5,7 +5,7 @@ use std::{marker::PhantomData, time::Duration}; -use air::{EvaluationFrame, LogUpGkrEvaluator, LogUpGkrOracle}; +use air::{EvaluationFrame, LogUpGkrEvaluator, LogUpGkrOracle, PeriodicTable}; use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion}; use crypto::{hashers::Blake3_192, DefaultRandomCoin, RandomCoin}; use math::{fields::f64::BaseElement, ExtensionOf, FieldElement, StarkField}; @@ -37,7 +37,7 @@ fn sum_check_high_degree(c: &mut Criterion) { ) }, |( - (claim, r_batch, rand_pt, (ml0, ml1, ml2, ml3, ml4)), + (claim, r_batch, rand_pt, (ml0, ml1, ml2, ml3, ml4), periodic_table), evaluator, logup_randomness, transcript, @@ -52,6 +52,7 @@ fn sum_check_high_degree(c: &mut Criterion) { r_batch, logup_randomness, mls, + periodic_table, &mut transcript, ) }, @@ -76,6 +77,7 @@ fn setup_sum_check( MultiLinearPoly, MultiLinearPoly, ), + PeriodicTable, ) { let n = 1 << log_size; let table = MultiLinearPoly::from_evaluations(rand_vector(n)); @@ -83,6 +85,7 @@ fn setup_sum_check( let values_0 = MultiLinearPoly::from_evaluations(rand_vector(n)); let values_1 = MultiLinearPoly::from_evaluations(rand_vector(n)); let values_2 = MultiLinearPoly::from_evaluations(rand_vector(n)); + let periodic_table = PeriodicTable::default(); // this will not generate the correct claim with overwhelming probability but should be fine // for benchmarking @@ -90,12 +93,18 @@ fn setup_sum_check( let r_batch: E = rand_value(); let claim: E = rand_value(); - (claim, r_batch, rand_pt, (table, multiplicity, values_0, values_1, values_2)) + ( + claim, + r_batch, + rand_pt, + (table, multiplicity, values_0, values_1, values_2), + periodic_table, + ) } #[derive(Clone, Default)] pub struct PlainLogUpGkrEval { - oracles: Vec>, + oracles: Vec, _field: PhantomData, } @@ -116,7 +125,7 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { type PublicInputs = (); - fn get_oracles(&self) -> &[LogUpGkrOracle] { + fn get_oracles(&self) -> &[LogUpGkrOracle] { &self.oracles } @@ -132,7 +141,7 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { 3 } - fn build_query(&self, frame: &EvaluationFrame, _periodic_values: &[E], query: &mut [E]) + fn build_query(&self, frame: &EvaluationFrame, query: &mut [E]) where E: FieldElement, { @@ -142,6 +151,7 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { fn evaluate_query( &self, query: &[F], + _periodic_values: &[F], rand_values: &[E], numerator: &mut [E], denominator: &mut [E], diff --git a/sumcheck/src/prover/high_degree.rs b/sumcheck/src/prover/high_degree.rs index 691195925..47be290d7 100644 --- a/sumcheck/src/prover/high_degree.rs +++ b/sumcheck/src/prover/high_degree.rs @@ -5,7 +5,7 @@ use alloc::vec::Vec; -use air::LogUpGkrEvaluator; +use air::{LogUpGkrEvaluator, PeriodicTable}; use crypto::{ElementHasher, RandomCoin}; use math::FieldElement; #[cfg(feature = "concurrent")] @@ -160,6 +160,7 @@ pub fn sum_check_prove_higher_degree< r_sum_check: E, log_up_randomness: Vec, mut mls: Vec>, + mut periodic_table: PeriodicTable, coin: &mut impl RandomCoin, ) -> Result, SumCheckProverError> { let num_rounds = mls[0].num_variables(); @@ -176,8 +177,15 @@ pub fn sum_check_prove_higher_degree< let mut current_round_claim = SumCheckRoundClaim { eval_point: vec![], claim }; // run the first round of the protocol - let round_poly_evals = - sumcheck_round(&eq_mu, evaluator, &eq_nu, &mls, &log_up_randomness, r_sum_check); + let round_poly_evals = sumcheck_round( + &eq_mu, + evaluator, + &eq_nu, + &mls, + &periodic_table, + &log_up_randomness, + r_sum_check, + ); let round_poly_coefs = round_poly_evals.to_poly(current_round_claim.claim); // reseed with the s_0 polynomial @@ -198,10 +206,20 @@ pub fn sum_check_prove_higher_degree< .for_each(|ml| ml.bind_least_significant_variable(round_challenge)); eq_nu.bind_least_significant_variable(round_challenge); + // fold each periodic multi-linear using the round challenge + periodic_table.bind_least_significant_variable(round_challenge); + // run the i-th round of the protocol using the folded multi-linears for the new reduced // claim. This basically computes the s_i polynomial. - let round_poly_evals = - sumcheck_round(&eq_mu, evaluator, &eq_nu, &mls, &log_up_randomness, r_sum_check); + let round_poly_evals = sumcheck_round( + &eq_mu, + evaluator, + &eq_nu, + &mls, + &periodic_table, + &log_up_randomness, + r_sum_check, + ); // update the claim current_round_claim = new_round_claim; @@ -280,21 +298,28 @@ fn sumcheck_round( evaluator: &impl LogUpGkrEvaluator::BaseField>, eq_ml: &MultiLinearPoly, mls: &[MultiLinearPoly], + periodic_table: &PeriodicTable, log_up_randomness: &[E], r_sum_check: E, ) -> CompressedUnivariatePolyEvals { - let num_ml = mls.len(); + let num_mls = mls.len(); + let num_periodic = periodic_table.num_columns(); let num_vars = mls[0].num_variables(); let num_rounds = num_vars - 1; #[cfg(not(feature = "concurrent"))] let evaluations = { - let mut evals_one = vec![E::ZERO; num_ml]; - let mut evals_zero = vec![E::ZERO; num_ml]; - let mut evals_x = vec![E::ZERO; num_ml]; + let mut evals_one = vec![E::ZERO; num_mls]; + let mut evals_zero = vec![E::ZERO; num_mls]; + let mut evals_x = vec![E::ZERO; num_mls]; + + let mut evals_periodic_one = vec![E::ZERO; num_periodic]; + let mut evals_periodic_zero = vec![E::ZERO; num_periodic]; + let mut evals_periodic_x = vec![E::ZERO; num_periodic]; let mut eq_x = E::ZERO; - let mut deltas = vec![E::ZERO; num_ml]; + let mut deltas = vec![E::ZERO; num_mls]; + let mut deltas_periodic = vec![E::ZERO; num_periodic]; let mut eq_delta = E::ZERO; let mut numerators = vec![E::ZERO; evaluator.get_num_fractions()]; @@ -311,9 +336,14 @@ fn sumcheck_round( let eq_at_zero = eq_ml.evaluations()[2 * i]; let eq_at_one = eq_ml.evaluations()[2 * i + 1]; + // add evaluation of periodic columns + periodic_table.fill_periodic_values_at(2 * i, &mut evals_periodic_zero); + periodic_table.fill_periodic_values_at(2 * i + 1, &mut evals_periodic_one); + // compute the evaluation at 1 evaluator.evaluate_query( &evals_one, + &evals_periodic_one, log_up_randomness, &mut numerators, &mut denominators, @@ -327,10 +357,14 @@ fn sumcheck_round( ); // compute the evaluations at 2, ..., d_max points - for i in 0..num_ml { + for i in 0..num_mls { deltas[i] = evals_one[i] - evals_zero[i]; evals_x[i] = evals_one[i]; } + for i in 0..num_periodic { + deltas_periodic[i] = evals_periodic_one[i] - evals_periodic_zero[i]; + evals_periodic_x[i] = evals_periodic_one[i]; + } eq_delta = eq_at_one - eq_at_zero; eq_x = eq_at_one; @@ -338,10 +372,16 @@ fn sumcheck_round( evals_x.iter_mut().zip(deltas.iter()).for_each(|(evx, delta)| { *evx += *delta; }); + evals_periodic_x.iter_mut().zip(deltas_periodic.iter()).for_each( + |(evx, delta)| { + *evx += *delta; + }, + ); eq_x += eq_delta; evaluator.evaluate_query( &evals_x, + &evals_periodic_x, log_up_randomness, &mut numerators, &mut denominators, @@ -371,23 +411,31 @@ fn sumcheck_round( .fold( || { ( - vec![E::ZERO; num_ml], - vec![E::ZERO; num_ml], - vec![E::ZERO; num_ml], + vec![E::ZERO; num_mls], + vec![E::ZERO; num_mls], + vec![E::ZERO; num_mls], + vec![E::ZERO; num_periodic], + vec![E::ZERO; num_periodic], + vec![E::ZERO; num_periodic], vec![E::ZERO; evaluator.max_degree()], vec![E::ZERO; evaluator.get_num_fractions()], vec![E::ZERO; evaluator.get_num_fractions()], - vec![E::ZERO; num_ml], + vec![E::ZERO; num_mls], + vec![E::ZERO; num_periodic], ) }, |( mut evals_zero, mut evals_one, mut evals_x, + mut evals_periodic_zero, + mut evals_periodic_one, + mut evals_periodic_x, mut poly_evals, mut numerators, mut denominators, mut deltas, + mut deltas_periodic, ), i| { for (j, ml) in mls.iter().enumerate() { @@ -398,9 +446,14 @@ fn sumcheck_round( let eq_at_zero = eq_ml.evaluations()[2 * i]; let eq_at_one = eq_ml.evaluations()[2 * i + 1]; + // add evaluation of periodic columns + periodic_table.fill_periodic_values_at(2 * i, &mut evals_periodic_zero); + periodic_table.fill_periodic_values_at(2 * i + 1, &mut evals_periodic_one); + // compute the evaluation at 1 evaluator.evaluate_query( &evals_one, + &evals_periodic_one, log_up_randomness, &mut numerators, &mut denominators, @@ -414,10 +467,14 @@ fn sumcheck_round( ); // compute the evaluations at 2, ..., d_max points - for i in 0..num_ml { + for i in 0..num_mls { deltas[i] = evals_one[i] - evals_zero[i]; evals_x[i] = evals_one[i]; } + for i in 0..num_periodic { + deltas_periodic[i] = evals_periodic_one[i] - evals_periodic_zero[i]; + evals_periodic_x[i] = evals_periodic_one[i]; + } let eq_delta = eq_at_one - eq_at_zero; let mut eq_x = eq_at_one; @@ -425,10 +482,16 @@ fn sumcheck_round( evals_x.iter_mut().zip(deltas.iter()).for_each(|(evx, delta)| { *evx += *delta; }); + evals_periodic_x.iter_mut().zip(deltas_periodic.iter()).for_each( + |(evx, delta)| { + *evx += *delta; + }, + ); eq_x += eq_delta; evaluator.evaluate_query( &evals_x, + &evals_periodic_x, log_up_randomness, &mut numerators, &mut denominators, @@ -442,7 +505,19 @@ fn sumcheck_round( ); } - (evals_zero, evals_one, evals_x, poly_evals, numerators, denominators, deltas) + ( + evals_zero, + evals_one, + evals_x, + evals_periodic_zero, + evals_periodic_one, + evals_periodic_x, + poly_evals, + numerators, + denominators, + deltas, + deltas_periodic, + ) }, ) .map(|(_, _, _, poly_evals, ..)| poly_evals) diff --git a/sumcheck/src/verifier/mod.rs b/sumcheck/src/verifier/mod.rs index 887598cc8..900be4c86 100644 --- a/sumcheck/src/verifier/mod.rs +++ b/sumcheck/src/verifier/mod.rs @@ -5,13 +5,13 @@ use alloc::vec::Vec; -use air::LogUpGkrEvaluator; +use air::{LogUpGkrEvaluator, PeriodicTable}; use crypto::{ElementHasher, RandomCoin}; use math::FieldElement; use crate::{ comb_func, evaluate_composition_poly, EqFunction, FinalLayerProof, FinalOpeningClaim, - RoundProof, SumCheckProof, SumCheckRoundClaim, + MultiLinearPoly, RoundProof, SumCheckProof, SumCheckRoundClaim, }; /// Verifies sum-check proofs, as part of the GKR proof, for all GKR layers except for the last one @@ -86,8 +86,14 @@ pub fn verify_sum_check_input_layer( + periodic_columns: PeriodicTable, + eval_point: &[E], +) -> Vec { + let mut evaluations = vec![]; + for col in periodic_columns.table() { + let ml = MultiLinearPoly::from_evaluations(col.to_vec()); + let num_variables = ml.num_variables(); + let point = &eval_point[..num_variables]; + + let evaluation = ml.evaluate(point); + evaluations.push(evaluation) + } + evaluations +} diff --git a/winterfell/src/tests/logup_gkr_periodic.rs b/winterfell/src/tests/logup_gkr_periodic.rs new file mode 100644 index 000000000..18ee38d00 --- /dev/null +++ b/winterfell/src/tests/logup_gkr_periodic.rs @@ -0,0 +1,357 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use std::{marker::PhantomData, vec, vec::Vec}; + +use air::{ + Air, AirContext, Assertion, AuxRandElements, ConstraintCompositionCoefficients, FieldExtension, + LogUpGkrEvaluator, LogUpGkrOracle, ProofOptions, TraceInfo, +}; +use crypto::MerkleTree; +use math::StarkField; + +use super::super::*; +use crate::{ + crypto::{hashers::Blake3_256, DefaultRandomCoin}, + math::{fields::f64::BaseElement, ExtensionOf, FieldElement}, + matrix::ColMatrix, + DefaultConstraintEvaluator, DefaultTraceLde, Prover, StarkDomain, TracePolyTable, +}; + +#[test] +fn test_logup_gkr_periodic() { + let aux_trace_width = 1; + let trace = LogUpGkrPeriodic::new(2_usize.pow(7), aux_trace_width); + let prover = LogUpGkrPeriodicProver::new(aux_trace_width); + + let proof = prover.prove(trace).unwrap(); + + verify::< + LogUpGkrPeriodicAir, + Blake3_256, + DefaultRandomCoin>, + MerkleTree>, + >(proof, (), &AcceptableOptions::MinConjecturedSecurity(0)) + .unwrap() +} + +// LogUpGkrPeriodic +// ================================================================================================= + +#[derive(Clone, Debug)] +struct LogUpGkrPeriodic { + // dummy main trace + main_trace: ColMatrix, + info: TraceInfo, +} + +impl LogUpGkrPeriodic { + fn new(trace_len: usize, aux_segment_width: usize) -> Self { + assert!(trace_len < u32::MAX.try_into().unwrap()); + + let table: Vec = + (0..trace_len).map(|idx| BaseElement::from(idx as u32)).collect(); + let mut multiplicity = vec![BaseElement::ZERO; trace_len]; + multiplicity.iter_mut().step_by(8).for_each(|m| *m = BaseElement::from(3_u32)); + + let mut values_0: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + + for i in 0..trace_len / 8 { + values_0[8 * i] = BaseElement::from(8 * i as u32); + } + + let mut values_1: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + + for i in 0..trace_len / 8 { + values_1[8 * i] = BaseElement::from(8 * i as u32); + } + + let mut values_2: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + + for i in 0..trace_len / 8 { + values_2[8 * i] = BaseElement::from(8 * i as u32); + } + + Self { + main_trace: ColMatrix::new(vec![table, multiplicity, values_0, values_1, values_2]), + info: TraceInfo::new_multi_segment(5, aux_segment_width, 0, trace_len, vec![], true), + } + } + + fn len(&self) -> usize { + self.main_trace.num_rows() + } +} + +impl Trace for LogUpGkrPeriodic { + type BaseField = BaseElement; + + fn info(&self) -> &TraceInfo { + &self.info + } + + fn main_segment(&self) -> &ColMatrix { + &self.main_trace + } + + fn read_main_frame(&self, row_idx: usize, frame: &mut EvaluationFrame) { + let next_row_idx = row_idx + 1; + self.main_trace.read_row_into(row_idx, frame.current_mut()); + self.main_trace.read_row_into(next_row_idx % self.len(), frame.next_mut()); + } +} + +// AIR +// ================================================================================================= + +struct LogUpGkrPeriodicAir { + context: AirContext, +} + +impl Air for LogUpGkrPeriodicAir { + type BaseField = BaseElement; + type PublicInputs = (); + + fn new(trace_info: TraceInfo, _pub_inputs: Self::PublicInputs, options: ProofOptions) -> Self { + Self { + context: AirContext::with_logup_gkr( + trace_info, + (), + vec![TransitionConstraintDegree::new(1)], + vec![], + 1, + 0, + options, + ), + } + } + + fn context(&self) -> &AirContext { + &self.context + } + + fn evaluate_transition>( + &self, + frame: &EvaluationFrame, + _periodic_values: &[E], + result: &mut [E], + ) { + let current = frame.current()[0]; + let next = frame.next()[0]; + + // increments by 1 + result[0] = next - current - E::ONE; + } + + fn get_assertions(&self) -> Vec> { + vec![Assertion::single(0, 0, BaseElement::ZERO)] + } + + fn evaluate_aux_transition( + &self, + _main_frame: &EvaluationFrame, + _aux_frame: &EvaluationFrame, + _periodic_values: &[F], + _aux_rand_elements: &AuxRandElements, + _result: &mut [E], + ) where + F: FieldElement, + E: FieldElement + ExtensionOf, + { + // do nothing + } + + fn get_aux_assertions>( + &self, + _aux_rand_elements: &AuxRandElements, + ) -> Vec> { + vec![] + } + + fn get_logup_gkr_evaluator( + &self, + ) -> impl LogUpGkrEvaluator + { + PeriodicLogUpGkrEval::new() + } +} + +#[derive(Clone, Default)] +pub struct PeriodicLogUpGkrEval { + oracles: Vec, + _field: PhantomData, +} + +impl PeriodicLogUpGkrEval { + pub fn new() -> Self { + let committed_0 = LogUpGkrOracle::CurrentRow(0); + let committed_1 = LogUpGkrOracle::CurrentRow(1); + let committed_2 = LogUpGkrOracle::CurrentRow(2); + let committed_3 = LogUpGkrOracle::CurrentRow(3); + let committed_4 = LogUpGkrOracle::CurrentRow(4); + + let oracles = vec![committed_0, committed_1, committed_2, committed_3, committed_4]; + + Self { oracles, _field: PhantomData } + } +} + +impl LogUpGkrEvaluator for PeriodicLogUpGkrEval { + type BaseField = BaseElement; + + type PublicInputs = (); + + fn get_oracles(&self) -> &[LogUpGkrOracle] { + &self.oracles + } + + fn get_periodic_column_values(&self) -> Vec> { + vec![vec![ + Self::BaseField::ONE, + Self::BaseField::ZERO, + Self::BaseField::ZERO, + Self::BaseField::ZERO, + Self::BaseField::ZERO, + Self::BaseField::ZERO, + Self::BaseField::ZERO, + Self::BaseField::ZERO, + ]] + } + + fn get_num_rand_values(&self) -> usize { + 1 + } + + fn get_num_fractions(&self) -> usize { + 4 + } + + fn max_degree(&self) -> usize { + 3 + } + + fn build_query(&self, frame: &EvaluationFrame, query: &mut [E]) + where + E: FieldElement, + { + query.iter_mut().zip(frame.current().iter()).for_each(|(q, f)| *q = *f) + } + + fn evaluate_query( + &self, + query: &[F], + periodic_values: &[F], + rand_values: &[E], + numerator: &mut [E], + denominator: &mut [E], + ) where + F: FieldElement, + E: FieldElement + ExtensionOf, + { + assert_eq!(numerator.len(), 4); + assert_eq!(denominator.len(), 4); + assert_eq!(query.len(), 5); + numerator[0] = E::from(query[1]); + numerator[1] = E::from(periodic_values[0]); + numerator[2] = E::from(periodic_values[0]); + numerator[3] = E::from(periodic_values[0]); + + denominator[0] = rand_values[0] - E::from(query[0]); + denominator[1] = -(rand_values[0] - E::from(query[2])); + denominator[2] = -(rand_values[0] - E::from(query[3])); + denominator[3] = -(rand_values[0] - E::from(query[4])); + } + + fn compute_claim(&self, _inputs: &Self::PublicInputs, _rand_values: &[E]) -> E + where + E: FieldElement, + { + E::ZERO + } +} + +// Prover +// ================================================================================================ + +struct LogUpGkrPeriodicProver { + aux_trace_width: usize, + options: ProofOptions, +} + +impl LogUpGkrPeriodicProver { + fn new(aux_trace_width: usize) -> Self { + Self { + aux_trace_width, + options: ProofOptions::new(1, 8, 0, FieldExtension::Quadratic, 2, 1), + } + } +} + +impl Prover for LogUpGkrPeriodicProver { + type BaseField = BaseElement; + type Air = LogUpGkrPeriodicAir; + type Trace = LogUpGkrPeriodic; + type HashFn = Blake3_256; + type VC = MerkleTree>; + type RandomCoin = DefaultRandomCoin; + type TraceLde> = + DefaultTraceLde; + type ConstraintEvaluator<'a, E: FieldElement> = + DefaultConstraintEvaluator<'a, LogUpGkrPeriodicAir, E>; + + fn get_pub_inputs(&self, _trace: &Self::Trace) -> <::Air as Air>::PublicInputs { + } + + fn options(&self) -> &ProofOptions { + &self.options + } + + fn new_trace_lde( + &self, + trace_info: &TraceInfo, + main_trace: &ColMatrix, + domain: &StarkDomain, + ) -> (Self::TraceLde, TracePolyTable) + where + E: math::FieldElement, + { + DefaultTraceLde::new(trace_info, main_trace, domain) + } + + fn new_evaluator<'a, E>( + &self, + air: &'a Self::Air, + aux_rand_elements: Option>, + composition_coefficients: ConstraintCompositionCoefficients, + ) -> Self::ConstraintEvaluator<'a, E> + where + E: math::FieldElement, + { + DefaultConstraintEvaluator::new(air, aux_rand_elements, composition_coefficients) + } + + fn build_aux_trace(&self, main_trace: &Self::Trace, _aux_rand_elements: &[E]) -> ColMatrix + where + E: FieldElement, + { + let main_trace = main_trace.main_segment(); + + let mut columns = Vec::new(); + + let rand_summed = E::from(777_u32); + for _ in 0..self.aux_trace_width { + // building a dummy auxiliary column + let column = main_trace + .get_column(0) + .iter() + .map(|row_val| rand_summed.mul_base(*row_val)) + .collect(); + + columns.push(column); + } + + ColMatrix::new(columns) + } +} diff --git a/winterfell/src/tests.rs b/winterfell/src/tests/logup_gkr_simple.rs similarity index 97% rename from winterfell/src/tests.rs rename to winterfell/src/tests/logup_gkr_simple.rs index 99e52971f..0fec04c96 100644 --- a/winterfell/src/tests.rs +++ b/winterfell/src/tests/logup_gkr_simple.rs @@ -12,7 +12,7 @@ use air::{ use crypto::MerkleTree; use math::StarkField; -use super::*; +use super::super::*; use crate::{ crypto::{hashers::Blake3_256, DefaultRandomCoin}, math::{fields::f64::BaseElement, ExtensionOf, FieldElement}, @@ -182,7 +182,7 @@ impl Air for LogUpGkrSimpleAir { #[derive(Clone, Default)] pub struct PlainLogUpGkrEval { - oracles: Vec>, + oracles: Vec, _field: PhantomData, } @@ -203,7 +203,7 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { type PublicInputs = (); - fn get_oracles(&self) -> &[LogUpGkrOracle] { + fn get_oracles(&self) -> &[LogUpGkrOracle] { &self.oracles } @@ -219,7 +219,7 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { 3 } - fn build_query(&self, frame: &EvaluationFrame, _periodic_values: &[E], query: &mut [E]) + fn build_query(&self, frame: &EvaluationFrame, query: &mut [E]) where E: FieldElement, { @@ -229,6 +229,7 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { fn evaluate_query( &self, query: &[F], + _periodic_values: &[F], rand_values: &[E], numerator: &mut [E], denominator: &mut [E], diff --git a/winterfell/src/tests/mod.rs b/winterfell/src/tests/mod.rs new file mode 100644 index 000000000..51881e55e --- /dev/null +++ b/winterfell/src/tests/mod.rs @@ -0,0 +1,8 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +mod logup_gkr_simple; + +mod logup_gkr_periodic; From c146b4bc2346340ccb676924ef6d173275a5957e Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Thu, 12 Sep 2024 02:57:07 +0200 Subject: [PATCH 05/19] feat: relax constraint degree checks in debug mode (#311) --- prover/src/constraints/evaluation_table.rs | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/prover/src/constraints/evaluation_table.rs b/prover/src/constraints/evaluation_table.rs index 826c61253..711504e3d 100644 --- a/prover/src/constraints/evaluation_table.rs +++ b/prover/src/constraints/evaluation_table.rs @@ -217,15 +217,13 @@ impl<'a, E: FieldElement> ConstraintEvaluationTable<'a, E> { self.expected_transition_degrees, actual_degrees ); - // make sure evaluation domain size does not exceed the size required by max degree - let expected_domain_size = - core::cmp::max(max_degree, self.domain.trace_length() + 1).next_power_of_two(); - assert_eq!( - expected_domain_size, - self.num_rows(), - "incorrect constraint evaluation domain size; expected {}, but was {}", - expected_domain_size, - self.num_rows() + // make sure the actual degrees are less than or equal to the expected degree bounds + assert!( + self.expected_transition_degrees >= actual_degrees, + "transition constraint degrees do not satisfy the expected degree bounds + \nexpected degree bounds: {:>3?}\nactual degrees: {:>3?}", + self.expected_transition_degrees, + actual_degrees ); } } From 48a1d6f6e92b2c842d188172716cb157a8b4dfa5 Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Thu, 12 Sep 2024 11:01:36 +0200 Subject: [PATCH 06/19] Fix degree checks (#312) * feat: relax constraint degree checks in debug mode * fix: remove equality test and restore evaluation domain sizes check --- prover/src/constraints/evaluation_table.rs | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/prover/src/constraints/evaluation_table.rs b/prover/src/constraints/evaluation_table.rs index 711504e3d..f6906eb43 100644 --- a/prover/src/constraints/evaluation_table.rs +++ b/prover/src/constraints/evaluation_table.rs @@ -210,11 +210,13 @@ impl<'a, E: FieldElement> ConstraintEvaluationTable<'a, E> { max_degree = core::cmp::max(max_degree, degree); } - // make sure expected and actual degrees are equal - assert_eq!( - self.expected_transition_degrees, actual_degrees, - "transition constraint degrees didn't match\nexpected: {:>3?}\nactual: {:>3?}", - self.expected_transition_degrees, actual_degrees + // make sure the actual degrees are less than or equal to the expected degree bounds + assert!( + self.expected_transition_degrees >= actual_degrees, + "transition constraint degrees do not satisfy the expected degree bounds + \nexpected degree bounds: {:>3?}\nactual degrees: {:>3?}", + self.expected_transition_degrees, + actual_degrees ); // make sure the actual degrees are less than or equal to the expected degree bounds From 946a1b5f91ee66e09a22dc948959261632dd75fd Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Fri, 13 Sep 2024 03:15:35 +0200 Subject: [PATCH 07/19] fix: remove duplicate code and restore domain size check (#313) --- prover/src/constraints/evaluation_table.rs | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/prover/src/constraints/evaluation_table.rs b/prover/src/constraints/evaluation_table.rs index f6906eb43..08c9167f2 100644 --- a/prover/src/constraints/evaluation_table.rs +++ b/prover/src/constraints/evaluation_table.rs @@ -219,13 +219,15 @@ impl<'a, E: FieldElement> ConstraintEvaluationTable<'a, E> { actual_degrees ); - // make sure the actual degrees are less than or equal to the expected degree bounds - assert!( - self.expected_transition_degrees >= actual_degrees, - "transition constraint degrees do not satisfy the expected degree bounds - \nexpected degree bounds: {:>3?}\nactual degrees: {:>3?}", - self.expected_transition_degrees, - actual_degrees + // make sure evaluation domain size does not exceed the size required by max degree + let expected_domain_size = + core::cmp::max(max_degree, self.domain.trace_length() + 1).next_power_of_two(); + assert_eq!( + expected_domain_size, + self.num_rows(), + "incorrect constraint evaluation domain size; expected {}, but was {}", + expected_domain_size, + self.num_rows() ); } } From 0a0f244e01e170fa7cf76ada36ad019817dec936 Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Fri, 13 Sep 2024 10:03:10 +0200 Subject: [PATCH 08/19] Cleanup AirContext (#314) --- air/src/air/context.rs | 46 ++++++---------------- air/src/air/trace_info.rs | 6 +++ winterfell/src/tests/logup_gkr_periodic.rs | 2 +- winterfell/src/tests/logup_gkr_simple.rs | 2 +- 4 files changed, 20 insertions(+), 36 deletions(-) diff --git a/air/src/air/context.rs b/air/src/air/context.rs index 8998e64e0..6eb035af9 100644 --- a/air/src/air/context.rs +++ b/air/src/air/context.rs @@ -26,7 +26,6 @@ pub struct AirContext { pub(super) trace_domain_generator: B, pub(super) lde_domain_generator: B, pub(super) num_transition_exemptions: usize, - pub(super) logup_gkr: bool, } impl AirContext { @@ -106,15 +105,17 @@ impl AirContext { ); assert!(num_main_assertions > 0, "at least one assertion must be specified"); - if trace_info.is_multi_segment() && !trace_info.logup_gkr_enabled() { - assert!( - !aux_transition_constraint_degrees.is_empty(), - "at least one transition constraint degree must be specified for the auxiliary trace segment" - ); - assert!( - num_aux_assertions > 0, - "at least one assertion must be specified against the auxiliary trace segment" - ); + if trace_info.is_multi_segment() { + if !trace_info.logup_gkr_enabled() { + assert!( + !aux_transition_constraint_degrees.is_empty(), + "at least one transition constraint degree must be specified for the auxiliary trace segment" + ); + assert!( + num_aux_assertions > 0, + "at least one assertion must be specified against the auxiliary trace segment" + ); + } } else { assert!( aux_transition_constraint_degrees.is_empty(), @@ -163,32 +164,9 @@ impl AirContext { trace_domain_generator: B::get_root_of_unity(trace_length.ilog2()), lde_domain_generator: B::get_root_of_unity(lde_domain_size.ilog2()), num_transition_exemptions: 1, - logup_gkr: false, } } - pub fn with_logup_gkr( - trace_info: TraceInfo, - pub_inputs: P, - main_transition_constraint_degrees: Vec, - aux_transition_constraint_degrees: Vec, - num_main_assertions: usize, - num_aux_assertions: usize, - options: ProofOptions, - ) -> Self { - let mut air_context = Self::new_multi_segment( - trace_info, - pub_inputs, - main_transition_constraint_degrees, - aux_transition_constraint_degrees, - num_main_assertions, - num_aux_assertions, - options, - ); - air_context.logup_gkr = true; - air_context - } - // PUBLIC ACCESSORS // -------------------------------------------------------------------------------------------- @@ -257,7 +235,7 @@ impl AirContext { /// Returns true if LogUp-GKR is enabled. pub fn logup_gkr_enabled(&self) -> bool { - self.logup_gkr + self.trace_info.logup_gkr_enabled() } /// Returns the total number of assertions defined for a computation, excluding the Lagrange diff --git a/air/src/air/trace_info.rs b/air/src/air/trace_info.rs index 8c3545539..29cf4726b 100644 --- a/air/src/air/trace_info.rs +++ b/air/src/air/trace_info.rs @@ -175,6 +175,9 @@ impl TraceInfo { } /// Returns true if an execution trace contains an auxiliary trace segment. + /// + /// This includes either the case when the auxiliary trace segment is user defined or the case + /// when the segment is created as part of LogUp-GKR. pub fn is_multi_segment(&self) -> bool { self.aux_segment_width > 0 || self.logup_gkr } @@ -187,6 +190,9 @@ impl TraceInfo { } /// Returns the number of columns in the auxiliary segment of an execution trace. + /// + /// This includes both the columns that are user defined as well as the two columns defined + /// as part of LogUp-GKR when the latter is enabled. pub fn aux_segment_width(&self) -> usize { self.aux_segment_width + 2 * self.logup_gkr as usize } diff --git a/winterfell/src/tests/logup_gkr_periodic.rs b/winterfell/src/tests/logup_gkr_periodic.rs index 18ee38d00..d6ae0c530 100644 --- a/winterfell/src/tests/logup_gkr_periodic.rs +++ b/winterfell/src/tests/logup_gkr_periodic.rs @@ -116,7 +116,7 @@ impl Air for LogUpGkrPeriodicAir { fn new(trace_info: TraceInfo, _pub_inputs: Self::PublicInputs, options: ProofOptions) -> Self { Self { - context: AirContext::with_logup_gkr( + context: AirContext::new_multi_segment( trace_info, (), vec![TransitionConstraintDegree::new(1)], diff --git a/winterfell/src/tests/logup_gkr_simple.rs b/winterfell/src/tests/logup_gkr_simple.rs index 0fec04c96..3ffe8ea3b 100644 --- a/winterfell/src/tests/logup_gkr_simple.rs +++ b/winterfell/src/tests/logup_gkr_simple.rs @@ -118,7 +118,7 @@ impl Air for LogUpGkrSimpleAir { fn new(trace_info: TraceInfo, _pub_inputs: Self::PublicInputs, options: ProofOptions) -> Self { Self { - context: AirContext::with_logup_gkr( + context: AirContext::new_multi_segment( trace_info, _pub_inputs, vec![TransitionConstraintDegree::new(1)], From f522feceadf6ffbb5dad596073594a7b6a29cd20 Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Tue, 17 Sep 2024 09:43:21 +0200 Subject: [PATCH 09/19] Add LogUp-GKR benchmark (#315) --- prover/Cargo.toml | 4 + prover/benches/logup_gkr.rs | 368 ++++++++++++++++++++++++++++++ sumcheck/Cargo.toml | 2 - sumcheck/benches/bind_variable.rs | 70 +----- sumcheck/benches/eq_function.rs | 73 +----- 5 files changed, 389 insertions(+), 128 deletions(-) create mode 100644 prover/benches/logup_gkr.rs diff --git a/prover/Cargo.toml b/prover/Cargo.toml index 6fef7f90f..37e45c472 100644 --- a/prover/Cargo.toml +++ b/prover/Cargo.toml @@ -15,6 +15,10 @@ rust-version = "1.78" [lib] bench = false +[[bench]] +name = "logup_gkr" +harness = false + [[bench]] name = "row_matrix" harness = false diff --git a/prover/benches/logup_gkr.rs b/prover/benches/logup_gkr.rs new file mode 100644 index 000000000..0ad845f4e --- /dev/null +++ b/prover/benches/logup_gkr.rs @@ -0,0 +1,368 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use std::{marker::PhantomData, time::Duration, vec::Vec}; + +use air::{ + Air, AirContext, Assertion, AuxRandElements, ConstraintCompositionCoefficients, + EvaluationFrame, FieldExtension, LogUpGkrEvaluator, LogUpGkrOracle, ProofOptions, TraceInfo, + TransitionConstraintDegree, +}; +use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion}; +use crypto::MerkleTree; +use math::StarkField; +use winter_prover::{ + crypto::{hashers::Blake3_256, DefaultRandomCoin}, + math::{fields::f64::BaseElement, ExtensionOf, FieldElement}, + matrix::ColMatrix, + DefaultConstraintEvaluator, DefaultTraceLde, Prover, StarkDomain, Trace, TracePolyTable, +}; + +const TRACE_LENS: [usize; 2] = [2_usize.pow(20), 2_usize.pow(21)]; +const AUX_TRACE_WIDTH: usize = 2; + +/// Simple end-to-end benchmark for LogUp-GKR. +/// +/// The main trace contains `5` columns and the LogUp relation is a simple one where we have: +/// +/// 1. a table of values from `0` to `trace_len - 1`. +/// 2. a multiplicity column containing the number of look ups for each value in the table. +/// 3. three columns with values contained in the table above. +/// +/// Given the above, the benchmark then gives an idea about the minimal overhead due to enabling +/// LogUp-GKR. The overhead could be bigger depending on the complexity of the LogUp relation. +fn prove_with_lagrange_kernel(c: &mut Criterion) { + let mut group = c.benchmark_group("prove with Lagrange kernel column"); + group.sample_size(10); + group.measurement_time(Duration::from_secs(20)); + + for &trace_len in TRACE_LENS.iter() { + group.bench_function(BenchmarkId::new("", trace_len), |b| { + let trace = LogUpGkrSimpleTrace::new(trace_len, AUX_TRACE_WIDTH); + let prover = LogUpGkrSimpleProver::new(AUX_TRACE_WIDTH); + + b.iter_batched( + || trace.clone(), + |trace| prover.prove(trace).unwrap(), + BatchSize::SmallInput, + ) + }); + } +} + +criterion_group!(lagrange_kernel_group, prove_with_lagrange_kernel); +criterion_main!(lagrange_kernel_group); + +// LogUpGkrSimple +// ================================================================================================= + +#[derive(Clone, Debug)] +struct LogUpGkrSimpleTrace { + // dummy main trace + main_trace: ColMatrix, + info: TraceInfo, +} + +impl LogUpGkrSimpleTrace { + fn new(trace_len: usize, aux_segment_width: usize) -> Self { + assert!(trace_len < u32::MAX.try_into().unwrap()); + + // we create a column for the table we are looking values into. These are just the integers + // from 0 to `trace_len`. + let table: Vec = + (0..trace_len).map(|idx| BaseElement::from(idx as u32)).collect(); + + // we create three columns that contains values contained in `table`. For simplicity, we + // look up only the values `0` or `1`, we look up the value `1` four times and the value `0` + // `trace_len - 4` times. + + let mut values_0: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + for i in 0..4 { + values_0[i + 4] = BaseElement::ONE; + } + + let mut values_1: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + for i in 0..4 { + values_1[i + 4] = BaseElement::ONE; + } + + let mut values_2: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + for i in 0..4 { + values_2[i + 4] = BaseElement::ONE; + } + + // we create the multiplicity column + let mut multiplicity: Vec = + (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + // we look up the value `1` four times in three columns + multiplicity[1] = BaseElement::new(3 * 4); + // we look up the value `0` `trace_len - 4` in three columns + multiplicity[0] = BaseElement::new(3 * trace_len as u64 - 3 * 4); + + Self { + main_trace: ColMatrix::new(vec![table, multiplicity, values_0, values_1, values_2]), + info: TraceInfo::new_multi_segment(5, aux_segment_width, 0, trace_len, vec![], true), + } + } + + fn len(&self) -> usize { + self.main_trace.num_rows() + } +} + +impl Trace for LogUpGkrSimpleTrace { + type BaseField = BaseElement; + + fn info(&self) -> &TraceInfo { + &self.info + } + + fn main_segment(&self) -> &ColMatrix { + &self.main_trace + } + + fn read_main_frame(&self, row_idx: usize, frame: &mut EvaluationFrame) { + let next_row_idx = row_idx + 1; + self.main_trace.read_row_into(row_idx, frame.current_mut()); + self.main_trace.read_row_into(next_row_idx % self.len(), frame.next_mut()); + } +} + +// AIR +// ================================================================================================= + +struct LogUpGkrSimpleAir { + context: AirContext, +} + +impl Air for LogUpGkrSimpleAir { + type BaseField = BaseElement; + type PublicInputs = (); + + fn new(trace_info: TraceInfo, _pub_inputs: Self::PublicInputs, options: ProofOptions) -> Self { + Self { + context: AirContext::new_multi_segment( + trace_info, + _pub_inputs, + vec![TransitionConstraintDegree::new(1)], + vec![], + 1, + 0, + options, + ), + } + } + + fn context(&self) -> &AirContext { + &self.context + } + + fn evaluate_transition>( + &self, + frame: &EvaluationFrame, + _periodic_values: &[E], + result: &mut [E], + ) { + let current = frame.current()[0]; + let next = frame.next()[0]; + + // increments by 1 + result[0] = next - current - E::ONE; + } + + fn get_assertions(&self) -> Vec> { + vec![Assertion::single(0, 0, BaseElement::ZERO)] + } + + fn evaluate_aux_transition( + &self, + _main_frame: &EvaluationFrame, + _aux_frame: &EvaluationFrame, + _periodic_values: &[F], + _aux_rand_elements: &AuxRandElements, + _result: &mut [E], + ) where + F: FieldElement, + E: FieldElement + ExtensionOf, + { + // do nothing + } + + fn get_aux_assertions>( + &self, + _aux_rand_elements: &AuxRandElements, + ) -> Vec> { + vec![] + } + + fn get_logup_gkr_evaluator( + &self, + ) -> impl LogUpGkrEvaluator + { + PlainLogUpGkrEval::new() + } +} + +#[derive(Clone, Default)] +pub struct PlainLogUpGkrEval { + oracles: Vec, + _field: PhantomData, +} + +impl PlainLogUpGkrEval { + pub fn new() -> Self { + let committed_0 = LogUpGkrOracle::CurrentRow(0); + let committed_1 = LogUpGkrOracle::CurrentRow(1); + let committed_2 = LogUpGkrOracle::CurrentRow(2); + let committed_3 = LogUpGkrOracle::CurrentRow(3); + let committed_4 = LogUpGkrOracle::CurrentRow(4); + let oracles = vec![committed_0, committed_1, committed_2, committed_3, committed_4]; + Self { oracles, _field: PhantomData } + } +} + +impl LogUpGkrEvaluator for PlainLogUpGkrEval { + type BaseField = BaseElement; + + type PublicInputs = (); + + fn get_oracles(&self) -> &[LogUpGkrOracle] { + &self.oracles + } + + fn get_num_rand_values(&self) -> usize { + 1 + } + + fn get_num_fractions(&self) -> usize { + 4 + } + + fn max_degree(&self) -> usize { + 3 + } + + fn build_query(&self, frame: &EvaluationFrame, query: &mut [E]) + where + E: FieldElement, + { + query.iter_mut().zip(frame.current().iter()).for_each(|(q, f)| *q = *f) + } + + fn evaluate_query( + &self, + query: &[F], + _periodic_values: &[F], + rand_values: &[E], + numerator: &mut [E], + denominator: &mut [E], + ) where + F: FieldElement, + E: FieldElement + ExtensionOf, + { + assert_eq!(numerator.len(), 4); + assert_eq!(denominator.len(), 4); + assert_eq!(query.len(), 5); + numerator[0] = E::from(query[1]); + numerator[1] = E::ONE; + numerator[2] = E::ONE; + numerator[3] = E::ONE; + + denominator[0] = rand_values[0] - E::from(query[0]); + denominator[1] = -(rand_values[0] - E::from(query[2])); + denominator[2] = -(rand_values[0] - E::from(query[3])); + denominator[3] = -(rand_values[0] - E::from(query[4])); + } + + fn compute_claim(&self, _inputs: &Self::PublicInputs, _rand_values: &[E]) -> E + where + E: FieldElement, + { + E::ZERO + } +} +// Prover +// ================================================================================================ + +struct LogUpGkrSimpleProver { + aux_trace_width: usize, + options: ProofOptions, +} + +impl LogUpGkrSimpleProver { + fn new(aux_trace_width: usize) -> Self { + Self { + aux_trace_width, + options: ProofOptions::new(1, 8, 0, FieldExtension::Quadratic, 2, 1), + } + } +} + +impl Prover for LogUpGkrSimpleProver { + type BaseField = BaseElement; + type Air = LogUpGkrSimpleAir; + type Trace = LogUpGkrSimpleTrace; + type HashFn = Blake3_256; + type VC = MerkleTree>; + type RandomCoin = DefaultRandomCoin; + type TraceLde> = + DefaultTraceLde; + type ConstraintEvaluator<'a, E: FieldElement> = + DefaultConstraintEvaluator<'a, LogUpGkrSimpleAir, E>; + + fn get_pub_inputs(&self, _trace: &Self::Trace) -> <::Air as Air>::PublicInputs { + } + + fn options(&self) -> &ProofOptions { + &self.options + } + + fn new_trace_lde( + &self, + trace_info: &TraceInfo, + main_trace: &ColMatrix, + domain: &StarkDomain, + ) -> (Self::TraceLde, TracePolyTable) + where + E: math::FieldElement, + { + DefaultTraceLde::new(trace_info, main_trace, domain) + } + + fn new_evaluator<'a, E>( + &self, + air: &'a Self::Air, + aux_rand_elements: Option>, + composition_coefficients: ConstraintCompositionCoefficients, + ) -> Self::ConstraintEvaluator<'a, E> + where + E: math::FieldElement, + { + DefaultConstraintEvaluator::new(air, aux_rand_elements, composition_coefficients) + } + + fn build_aux_trace(&self, main_trace: &Self::Trace, _aux_rand_elements: &[E]) -> ColMatrix + where + E: FieldElement, + { + let main_trace = main_trace.main_segment(); + + let mut columns = Vec::new(); + + let rand_summed = E::from(777_u32); + for _ in 0..self.aux_trace_width { + // building a dummy auxiliary column + let column = main_trace + .get_column(0) + .iter() + .map(|row_val| rand_summed.mul_base(*row_val)) + .collect(); + + columns.push(column); + } + + ColMatrix::new(columns) + } +} diff --git a/sumcheck/Cargo.toml b/sumcheck/Cargo.toml index 97dc79f3a..7db2e8058 100644 --- a/sumcheck/Cargo.toml +++ b/sumcheck/Cargo.toml @@ -23,12 +23,10 @@ harness = false [[bench]] name = "eq_function" harness = false -required-features = ["concurrent"] [[bench]] name = "bind_variable" harness = false -required-features = ["concurrent"] [features] concurrent = ["utils/concurrent", "dep:rayon", "std"] diff --git a/sumcheck/benches/bind_variable.rs b/sumcheck/benches/bind_variable.rs index f7e82126c..07bbcc18d 100644 --- a/sumcheck/benches/bind_variable.rs +++ b/sumcheck/benches/bind_variable.rs @@ -6,29 +6,29 @@ use std::time::Duration; use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion}; -use math::{fields::f64::BaseElement, FieldElement}; +use math::fields::f64::BaseElement; use rand_utils::{rand_value, rand_vector}; #[cfg(feature = "concurrent")] pub use rayon::prelude::*; +use winter_sumcheck::MultiLinearPoly; const POLY_SIZE: [usize; 2] = [1 << 18, 1 << 20]; -fn bind_variable_serial(c: &mut Criterion) { - let mut group = c.benchmark_group("Bind variable evaluations"); +fn bind_variable(c: &mut Criterion) { + let mut group = c.benchmark_group("bind variable "); group.sample_size(10); group.measurement_time(Duration::from_secs(10)); for &poly_size in POLY_SIZE.iter() { - group.bench_function(BenchmarkId::new("serial", poly_size), |b| { + group.bench_function(BenchmarkId::new("", poly_size), |b| { b.iter_batched( || { let random_challenge: BaseElement = rand_value(); - let poly_evals: Vec = rand_vector(poly_size); - (random_challenge, poly_evals) + let poly = MultiLinearPoly::from_evaluations(rand_vector(poly_size)); + (random_challenge, poly) }, - |(random_challenge, poly_evals)| { - let mut poly_evals = poly_evals; - bind_least_significant_variable_serial(&mut poly_evals, random_challenge) + |(random_challenge, mut poly)| { + poly.bind_least_significant_variable(random_challenge) }, BatchSize::SmallInput, ) @@ -36,55 +36,5 @@ fn bind_variable_serial(c: &mut Criterion) { } } -fn bind_variable_parallel(c: &mut Criterion) { - let mut group = c.benchmark_group("Bind variable function evaluations"); - group.sample_size(10); - group.measurement_time(Duration::from_secs(10)); - - for &poly_size in POLY_SIZE.iter() { - group.bench_function(BenchmarkId::new("parallel", poly_size), |b| { - b.iter_batched( - || { - let random_challenge: BaseElement = rand_value(); - let poly_evals: Vec = rand_vector(poly_size); - (random_challenge, poly_evals) - }, - |(random_challenge, poly_evals)| { - let mut poly_evals = poly_evals; - bind_least_significant_variable_parallel(&mut poly_evals, random_challenge) - }, - BatchSize::SmallInput, - ) - }); - } -} - -fn bind_least_significant_variable_serial( - evaluations: &mut Vec, - round_challenge: E, -) { - let num_evals = evaluations.len() >> 1; - - for i in 0..num_evals { - evaluations[i] = evaluations[i << 1] - + round_challenge * (evaluations[(i << 1) + 1] - evaluations[i << 1]); - } - evaluations.truncate(num_evals); -} - -fn bind_least_significant_variable_parallel( - evaluations: &mut Vec, - round_challenge: E, -) { - let num_evals = evaluations.len() >> 1; - - let mut result = unsafe { utils::uninit_vector(num_evals) }; - result.par_iter_mut().enumerate().for_each(|(i, ev)| { - *ev = evaluations[i << 1] - + round_challenge * (evaluations[(i << 1) + 1] - evaluations[i << 1]) - }); - *evaluations = result -} - -criterion_group!(group, bind_variable_serial, bind_variable_parallel); +criterion_group!(group, bind_variable); criterion_main!(group); diff --git a/sumcheck/benches/eq_function.rs b/sumcheck/benches/eq_function.rs index 86e6cad98..df2326f95 100644 --- a/sumcheck/benches/eq_function.rs +++ b/sumcheck/benches/eq_function.rs @@ -6,91 +6,32 @@ use std::time::Duration; use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion}; -use math::{fields::f64::BaseElement, FieldElement}; +use math::fields::f64::BaseElement; use rand_utils::rand_vector; #[cfg(feature = "concurrent")] pub use rayon::prelude::*; +use winter_sumcheck::EqFunction; const LOG_POLY_SIZE: [usize; 2] = [18, 20]; -fn evaluate_eq_serial(c: &mut Criterion) { +fn evaluate_eq(c: &mut Criterion) { let mut group = c.benchmark_group("EQ function evaluations"); group.sample_size(10); group.measurement_time(Duration::from_secs(10)); for &log_poly_size in LOG_POLY_SIZE.iter() { - group.bench_function(BenchmarkId::new("serial", log_poly_size), |b| { + group.bench_function(BenchmarkId::new("", log_poly_size), |b| { b.iter_batched( || { let randomness: Vec = rand_vector(log_poly_size); - randomness + EqFunction::new(randomness.into()) }, - |rand| eq_evaluations(&rand), + |eq_function| eq_function.evaluations(), BatchSize::SmallInput, ) }); } } -fn evaluate_eq_parallel(c: &mut Criterion) { - let mut group = c.benchmark_group("EQ function evaluations"); - group.sample_size(10); - group.measurement_time(Duration::from_secs(10)); - - for &log_poly_size in LOG_POLY_SIZE.iter() { - group.bench_function(BenchmarkId::new("parallel", log_poly_size), |b| { - b.iter_batched( - || { - let randomness: Vec = rand_vector(log_poly_size); - randomness - }, - |rand| eq_evaluations_par(&rand), - BatchSize::SmallInput, - ) - }); - } -} - -fn eq_evaluations(query: &[E]) -> Vec { - let n = 1 << query.len(); - let mut evals = unsafe { utils::uninit_vector(n) }; - - let mut size = 1; - evals[0] = E::ONE; - for r_i in query.iter() { - let (left_evals, right_evals) = evals.split_at_mut(size); - left_evals.iter_mut().zip(right_evals.iter_mut()).for_each(|(left, right)| { - let factor = *left; - *right = factor * *r_i; - *left -= *right; - }); - - size *= 2; - } - evals -} - -fn eq_evaluations_par(query: &[E]) -> Vec { - let n = 1 << query.len(); - let mut evals = unsafe { utils::uninit_vector(n) }; - - let mut size = 1; - evals[0] = E::ONE; - for r_i in query.iter() { - let (left_evals, right_evals) = evals.split_at_mut(size); - left_evals - .par_iter_mut() - .zip(right_evals.par_iter_mut()) - .for_each(|(left, right)| { - let factor = *left; - *right = factor * *r_i; - *left -= *right; - }); - - size <<= 1; - } - evals -} - -criterion_group!(group, evaluate_eq_serial, evaluate_eq_parallel); +criterion_group!(group, evaluate_eq); criterion_main!(group); From 2713a951610643bde4141fc14c42c2a11f687d1c Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Wed, 18 Sep 2024 08:30:48 +0200 Subject: [PATCH 10/19] Parallelize Lagrange constraints evaluation (#317) --- air/src/air/logup_gkr/lagrange/boundary.rs | 15 +- air/src/air/logup_gkr/lagrange/frame.rs | 30 +- air/src/air/logup_gkr/lagrange/transition.rs | 17 +- air/src/proof/ood_frame.rs | 2 +- prover/benches/logup_gkr.rs | 8 +- prover/src/constraints/evaluation_table.rs | 29 +- prover/src/constraints/evaluator/default.rs | 76 +-- prover/src/constraints/evaluator/logup_gkr.rs | 316 ++++++------- .../evaluator/logup_gkr_evaluator.rs | 433 ++++++++++++++++++ prover/src/constraints/evaluator/mod.rs | 3 + prover/src/constraints/mod.rs | 2 +- prover/src/lib.rs | 2 +- prover/src/trace/mod.rs | 2 +- prover/src/trace/trace_lde/default/mod.rs | 7 +- winterfell/src/lib.rs | 5 +- winterfell/src/tests/logup_gkr_periodic.rs | 6 +- winterfell/src/tests/logup_gkr_simple.rs | 6 +- 17 files changed, 670 insertions(+), 289 deletions(-) create mode 100644 prover/src/constraints/evaluator/logup_gkr_evaluator.rs diff --git a/air/src/air/logup_gkr/lagrange/boundary.rs b/air/src/air/logup_gkr/lagrange/boundary.rs index f3cc886aa..3eaad9f5d 100644 --- a/air/src/air/logup_gkr/lagrange/boundary.rs +++ b/air/src/air/logup_gkr/lagrange/boundary.rs @@ -31,27 +31,28 @@ where } } + /// Returns the constraint composition coefficient for this boundary constraint. + pub fn constraint_composition_coefficient(&self) -> E { + self.composition_coefficient + } + /// Returns the evaluation of the boundary constraint at `x`, multiplied by the composition /// coefficient. /// /// `frame` is the evaluation frame of the Lagrange kernel column `c`, starting at `c(x)` pub fn evaluate_at(&self, x: E, frame: &LagrangeKernelEvaluationFrame) -> E { - let numerator = self.evaluate_numerator_at(frame); + let numerator = self.evaluate_numerator_at(frame) * self.composition_coefficient; let denominator = self.evaluate_denominator_at(x); numerator / denominator } - /// Returns the evaluation of the boundary constraint numerator, multiplied by the composition - /// coefficient. + /// Returns the evaluation of the boundary constraint numerator. /// /// `frame` is the evaluation frame of the Lagrange kernel column `c`, starting at `c(x)` for /// some `x` pub fn evaluate_numerator_at(&self, frame: &LagrangeKernelEvaluationFrame) -> E { - let trace_value = frame.inner()[0]; - let constraint_evaluation = trace_value - self.assertion_value; - - constraint_evaluation * self.composition_coefficient + frame[0] - self.assertion_value } /// Returns the evaluation of the boundary constraint denominator at point `x`. diff --git a/air/src/air/logup_gkr/lagrange/frame.rs b/air/src/air/logup_gkr/lagrange/frame.rs index d0ffc4fa4..6dc0a64cc 100644 --- a/air/src/air/logup_gkr/lagrange/frame.rs +++ b/air/src/air/logup_gkr/lagrange/frame.rs @@ -4,6 +4,7 @@ // LICENSE file in the root directory of this source tree. use alloc::vec::Vec; +use core::ops::{Index, IndexMut}; use math::{polynom, FieldElement, StarkField}; @@ -25,14 +26,15 @@ impl LagrangeKernelEvaluationFrame { // -------------------------------------------------------------------------------------------- /// Constructs a Lagrange kernel evaluation frame from the raw column polynomial evaluations. - pub fn new(frame: Vec) -> Self { + pub fn with_values(frame: Vec) -> Self { Self { frame } } /// Constructs an empty Lagrange kernel evaluation frame from the raw column polynomial /// evaluations. The frame can subsequently be filled using [`Self::frame_mut`]. - pub fn new_empty() -> Self { - Self { frame: Vec::new() } + pub fn new(trace_len: usize) -> Self { + let frame_length = trace_len.ilog2() as usize + 1; + Self { frame: vec![E::ZERO; frame_length] } } /// Constructs the frame from the Lagrange kernel column trace polynomial coefficients for an @@ -61,14 +63,6 @@ impl LagrangeKernelEvaluationFrame { Self { frame } } - // MUTATORS - // -------------------------------------------------------------------------------------------- - - /// Returns a mutable reference to the inner frame. - pub fn frame_mut(&mut self) -> &mut Vec { - &mut self.frame - } - // ACCESSORS // -------------------------------------------------------------------------------------------- @@ -84,3 +78,17 @@ impl LagrangeKernelEvaluationFrame { self.frame.len() } } + +impl Index for LagrangeKernelEvaluationFrame { + type Output = E; + + fn index(&self, index: usize) -> &Self::Output { + &self.frame[index] + } +} + +impl IndexMut for LagrangeKernelEvaluationFrame { + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + &mut self.frame[index] + } +} diff --git a/air/src/air/logup_gkr/lagrange/transition.rs b/air/src/air/logup_gkr/lagrange/transition.rs index 18bdfa9be..5f5b110e6 100644 --- a/air/src/air/logup_gkr/lagrange/transition.rs +++ b/air/src/air/logup_gkr/lagrange/transition.rs @@ -43,6 +43,11 @@ impl LagrangeKernelTransitionConstraints { } } + /// Returns the constraint composition coefficients for the Lagrange kernel transition constraints. + pub fn lagrange_constraint_coefficients(&self) -> &[E] { + &self.lagrange_constraint_coefficients + } + /// Evaluates the numerator of the `constraint_idx`th transition constraint. pub fn evaluate_ith_numerator( &self, @@ -54,14 +59,12 @@ impl LagrangeKernelTransitionConstraints { F: FieldElement, E: ExtensionOf, { - let c = lagrange_kernel_column_frame.inner(); - let v = c.len() - 1; + let c = lagrange_kernel_column_frame; + let v = c.num_rows() - 1; let r = lagrange_kernel_rand_elements; let k = constraint_idx + 1; - let eval = (r[v - k] * c[0]) - ((E::ONE - r[v - k]) * c[v - k + 1]); - - self.lagrange_constraint_coefficients[constraint_idx].mul_base(eval) + (r[v - k] * c[0]) - ((E::ONE - r[v - k]) * c[v - k + 1]) } /// Evaluates the divisor of the `constraint_idx`th transition constraint. @@ -124,8 +127,8 @@ impl LagrangeKernelTransitionConstraints { let log2_trace_len = lagrange_kernel_column_frame.num_rows() - 1; let mut transition_evals = vec![E::ZERO; log2_trace_len]; - let c = lagrange_kernel_column_frame.inner(); - let v = c.len() - 1; + let c = lagrange_kernel_column_frame; + let v = c.num_rows() - 1; let r = lagrange_kernel_rand_elements; for k in 1..v + 1 { diff --git a/air/src/proof/ood_frame.rs b/air/src/proof/ood_frame.rs index feab1b260..9ae017094 100644 --- a/air/src/proof/ood_frame.rs +++ b/air/src/proof/ood_frame.rs @@ -131,7 +131,7 @@ impl OodFrame { let lagrange_kernel_frame = if lagrange_kernel_frame_size > 0 { let lagrange_kernel_trace = reader.read_many(lagrange_kernel_frame_size)?; - Some(LagrangeKernelEvaluationFrame::new(lagrange_kernel_trace)) + Some(LagrangeKernelEvaluationFrame::with_values(lagrange_kernel_trace)) } else { None }; diff --git a/prover/benches/logup_gkr.rs b/prover/benches/logup_gkr.rs index 0ad845f4e..6e67eddc2 100644 --- a/prover/benches/logup_gkr.rs +++ b/prover/benches/logup_gkr.rs @@ -17,10 +17,10 @@ use winter_prover::{ crypto::{hashers::Blake3_256, DefaultRandomCoin}, math::{fields::f64::BaseElement, ExtensionOf, FieldElement}, matrix::ColMatrix, - DefaultConstraintEvaluator, DefaultTraceLde, Prover, StarkDomain, Trace, TracePolyTable, + DefaultTraceLde, LogUpGkrConstraintEvaluator, Prover, StarkDomain, Trace, TracePolyTable, }; -const TRACE_LENS: [usize; 2] = [2_usize.pow(20), 2_usize.pow(21)]; +const TRACE_LENS: [usize; 2] = [2_usize.pow(18), 2_usize.pow(20)]; const AUX_TRACE_WIDTH: usize = 2; /// Simple end-to-end benchmark for LogUp-GKR. @@ -310,7 +310,7 @@ impl Prover for LogUpGkrSimpleProver { type TraceLde> = DefaultTraceLde; type ConstraintEvaluator<'a, E: FieldElement> = - DefaultConstraintEvaluator<'a, LogUpGkrSimpleAir, E>; + LogUpGkrConstraintEvaluator<'a, LogUpGkrSimpleAir, E>; fn get_pub_inputs(&self, _trace: &Self::Trace) -> <::Air as Air>::PublicInputs { } @@ -340,7 +340,7 @@ impl Prover for LogUpGkrSimpleProver { where E: math::FieldElement, { - DefaultConstraintEvaluator::new(air, aux_rand_elements, composition_coefficients) + LogUpGkrConstraintEvaluator::new(air, aux_rand_elements.unwrap(), composition_coefficients) } fn build_aux_trace(&self, main_trace: &Self::Trace, _aux_rand_elements: &[E]) -> ColMatrix diff --git a/prover/src/constraints/evaluation_table.rs b/prover/src/constraints/evaluation_table.rs index 08c9167f2..5ec4f92ee 100644 --- a/prover/src/constraints/evaluation_table.rs +++ b/prover/src/constraints/evaluation_table.rs @@ -46,8 +46,9 @@ impl<'a, E: FieldElement> ConstraintEvaluationTable<'a, E> { pub fn new( domain: &'a StarkDomain, divisors: Vec>, + logup_gkr_enabled: bool, ) -> Self { - let num_columns = divisors.len(); + let num_columns = divisors.len() + logup_gkr_enabled as usize; let num_rows = domain.ce_domain_size(); ConstraintEvaluationTable { evaluations: uninit_matrix(num_columns, num_rows), @@ -64,8 +65,9 @@ impl<'a, E: FieldElement> ConstraintEvaluationTable<'a, E> { domain: &'a StarkDomain, divisors: Vec>, transition_constraints: &TransitionConstraints, + logup_gkr_enabled: bool, ) -> Self { - let num_columns = divisors.len(); + let num_columns = divisors.len() + logup_gkr_enabled as usize; let num_rows = domain.ce_domain_size(); let num_tm_columns = transition_constraints.num_main_constraints(); let num_ta_columns = transition_constraints.num_aux_constraints(); @@ -161,13 +163,28 @@ impl<'a, E: FieldElement> ConstraintEvaluationTable<'a, E> { /// Divides constraint evaluation columns by their respective divisor (in evaluation form) and /// combines the results into a single column. pub fn combine(self) -> Vec { - // allocate memory for the combined polynomial - let mut combined_poly = vec![E::ZERO; self.num_rows()]; + // when LogUp-GKR is enabled, the last column contains the constraint evaluations of + // the Lagrange kernel column and the s-column. These evaluations were already divided by + // their respective divisors, and hence we just have to add them to `combined_poly`. + let mut combined_poly = if self.evaluations.len() != self.divisors.len() { + // allocate memory for the combined polynomial + let mut combined_poly = unsafe { uninit_vector(self.num_rows()) }; + + iter_mut!(combined_poly) + .enumerate() + .for_each(|(i, row)| *row = self.evaluations[self.divisors.len()][i]); + combined_poly + } else { + vec![E::ZERO; self.num_rows()] + }; // iterate over all columns of the constraint evaluation table, divide each column // by the evaluations of its corresponding divisor, and add all resulting evaluations - // together into a single vector - for (column, divisor) in self.evaluations.into_iter().zip(self.divisors.iter()) { + // together into a single vector. When LogUp-GKR is enabled, we skip the last two columns + // of the evaluation table as these were already handled above. + for (column, divisor) in + self.evaluations.into_iter().take(self.divisors.len()).zip(self.divisors.iter()) + { // divide the column by the divisor and accumulate the result into combined_poly acc_column(column, divisor, self.domain, &mut combined_poly); } diff --git a/prover/src/constraints/evaluator/default.rs b/prover/src/constraints/evaluator/default.rs index 4373494f9..1a5da1b4b 100644 --- a/prover/src/constraints/evaluator/default.rs +++ b/prover/src/constraints/evaluator/default.rs @@ -13,9 +13,8 @@ use utils::iter_mut; use utils::{iterators::*, rayon}; use super::{ - super::EvaluationTableFragment, logup_gkr::LogUpGkrConstraintsEvaluator, BoundaryConstraints, - CompositionPolyTrace, ConstraintEvaluationTable, ConstraintEvaluator, PeriodicValueTable, - StarkDomain, TraceLde, + super::EvaluationTableFragment, BoundaryConstraints, CompositionPolyTrace, + ConstraintEvaluationTable, ConstraintEvaluator, PeriodicValueTable, StarkDomain, TraceLde, }; // CONSTANTS @@ -40,7 +39,6 @@ pub struct DefaultConstraintEvaluator<'a, A: Air, E: FieldElement, transition_constraints: TransitionConstraints, - logup_gkr_constraints_evaluator: Option>, aux_rand_elements: Option>, periodic_values: PeriodicValueTable, } @@ -80,10 +78,14 @@ where // memory to hold all transition constraint evaluations (before they are merged into a // single value) so that we can check their degrees later #[cfg(not(debug_assertions))] - let mut evaluation_table = ConstraintEvaluationTable::::new(domain, divisors); + let mut evaluation_table = ConstraintEvaluationTable::::new(domain, divisors, false); #[cfg(debug_assertions)] - let mut evaluation_table = - ConstraintEvaluationTable::::new(domain, divisors, &self.transition_constraints); + let mut evaluation_table = ConstraintEvaluationTable::::new( + domain, + divisors, + &self.transition_constraints, + false, + ); // when `concurrent` feature is enabled, break the evaluation table into multiple fragments // to evaluate them into multiple threads; unless the constraint evaluation domain is small, @@ -116,16 +118,7 @@ where #[cfg(debug_assertions)] evaluation_table.validate_transition_degrees(); - // combine all constraint evaluations into a single column, including the evaluations of the - // LogUp-GKR constraints (if present) - let combined_evaluations = { - let mut constraints_evaluations = evaluation_table.combine(); - self.evaluate_logup_gkr_constraints(trace, domain, &mut constraints_evaluations); - - constraints_evaluations - }; - - CompositionPolyTrace::new(combined_evaluations) + CompositionPolyTrace::new(evaluation_table.combine()) } } @@ -143,6 +136,11 @@ where aux_rand_elements: Option>, composition_coefficients: ConstraintCompositionCoefficients, ) -> Self { + assert!( + !air.context().logup_gkr_enabled(), + "evaluating LogUp-GKR constraints is not supported in `DefaultConstraintEvaluator`" + ); + // build transition constraint groups; these will be used to compose transition constraint // evaluations let transition_constraints = @@ -158,31 +156,10 @@ where &composition_coefficients.boundary, ); - let logup_gkr_constraints_evaluator = if air.context().logup_gkr_enabled() { - let aux_rand_elements = - aux_rand_elements.as_ref().expect("expected aux rand elements to be present"); - - Some(LogUpGkrConstraintsEvaluator::new( - air, - aux_rand_elements - .gkr_data() - .expect("expected LogUp-GKR randomness to be present"), - composition_coefficients - .lagrange - .expect("expected Lagrange kernel composition coefficients to be present"), - composition_coefficients - .s_col - .expect("expected s-column composition coefficient to be present"), - )) - } else { - None - }; - DefaultConstraintEvaluator { air, boundary_constraints, transition_constraints, - logup_gkr_constraints_evaluator, aux_rand_elements, periodic_values, } @@ -298,29 +275,6 @@ where } } - /// If present, evaluates the LogUp-GKR constraints over the constraint evaluation domain. - /// The evaluation of each constraint (both boundary and transition) is divided by its divisor, - /// multiplied by its composition coefficient, the result of which is added to - /// `combined_evaluations_accumulator`. - /// - /// Specifically, `combined_evaluations_accumulator` is a buffer whose length is the size of the - /// constraint evaluation domain, where each index contains combined evaluations of other - /// constraints in the system. - fn evaluate_logup_gkr_constraints>( - &self, - trace: &T, - domain: &StarkDomain, - combined_evaluations_accumulator: &mut [E], - ) { - if let Some(ref logup_gkr_constraints_evaluator) = self.logup_gkr_constraints_evaluator { - logup_gkr_constraints_evaluator.evaluate_constraints( - trace, - domain, - combined_evaluations_accumulator, - ) - } - } - // TRANSITION CONSTRAINT EVALUATORS // -------------------------------------------------------------------------------------------- diff --git a/prover/src/constraints/evaluator/logup_gkr.rs b/prover/src/constraints/evaluator/logup_gkr.rs index cc7390b73..a729c7545 100644 --- a/prover/src/constraints/evaluator/logup_gkr.rs +++ b/prover/src/constraints/evaluator/logup_gkr.rs @@ -6,34 +6,43 @@ use alloc::vec::Vec; use air::{ - Air, EvaluationFrame, GkrData, LagrangeConstraintsCompositionCoefficients, - LagrangeKernelConstraints, LagrangeKernelEvaluationFrame, LogUpGkrEvaluator, + Air, GkrData, LagrangeConstraintsCompositionCoefficients, LagrangeKernelConstraints, + LogUpGkrEvaluator, }; use math::{batch_inversion, FieldElement}; -use crate::{StarkDomain, TraceLde}; +use crate::StarkDomain; /// Contains a specific strategy for evaluating the Lagrange kernel and s-column boundary and /// transition constraints. -pub struct LogUpGkrConstraintsEvaluator<'a, E: FieldElement, A: Air> { - air: &'a A, - lagrange_kernel_constraints: LagrangeKernelConstraints, - gkr_data: GkrData, - s_col_composition_coefficient: E, +pub struct LogUpGkrConstraintsEvaluator { + pub(crate) lagrange_kernel_constraints: LagrangeKernelConstraints, + pub(crate) gkr_data: GkrData, + pub(crate) s_col_composition_coefficient: E, + pub(crate) s_col_idx: usize, + pub(crate) l_col_idx: usize, + pub(crate) mean: E, } -impl<'a, E, A> LogUpGkrConstraintsEvaluator<'a, E, A> +impl LogUpGkrConstraintsEvaluator where E: FieldElement, - A: Air, { /// Constructs a new [`LogUpGkrConstraintsEvaluator`]. - pub fn new( - air: &'a A, + pub fn new>( + air: &A, gkr_data: GkrData, lagrange_composition_coefficients: LagrangeConstraintsCompositionCoefficients, s_col_composition_coefficient: E, ) -> Self { + let trace_info = air.trace_info(); + let s_col_idx = trace_info.s_column_idx().expect("S-column should be present"); + let l_col_idx = trace_info + .lagrange_kernel_column_idx() + .expect("Lagrange kernel should be present"); + + let c = gkr_data.compute_batched_claim(); + let mean = c / E::from(E::BaseField::from(trace_info.length() as u32)); Self { lagrange_kernel_constraints: air .get_logup_gkr_evaluator() @@ -41,148 +50,22 @@ where lagrange_composition_coefficients, gkr_data.lagrange_kernel_rand_elements(), ), - air, gkr_data, s_col_composition_coefficient, + s_col_idx, + l_col_idx, + mean, } } - - /// Evaluates the transition and boundary constraints. Specifically, the constraint evaluations - /// are divided by their corresponding divisors, and the resulting terms are linearly combined - /// using the constraint composition coefficients. - /// - /// Writes the evaluations in `combined_evaluations_acc` at the corresponding (constraint - /// evaluation) domain index. - pub fn evaluate_constraints( - &self, - trace: &T, - domain: &StarkDomain, - combined_evaluations_acc: &mut [E], - ) where - T: TraceLde, - { - let lde_shift = domain.ce_to_lde_blowup().trailing_zeros(); - let trans_constraints_divisors = LagrangeKernelTransitionConstraintsDivisor::new( - self.lagrange_kernel_constraints.transition.num_constraints(), - domain, - ); - let boundary_divisors_inv = self.compute_boundary_divisors_inv(domain); - - let mut lagrange_frame = LagrangeKernelEvaluationFrame::new_empty(); - - let evaluator = self.air.get_logup_gkr_evaluator(); - let s_col_constraint_divisor = compute_s_col_divisor::(domain, self.air.trace_length()); - let s_col_idx = trace.trace_info().s_column_idx().expect("S-column should be present"); - let l_col_idx = trace - .trace_info() - .lagrange_kernel_column_idx() - .expect("Lagrange kernel should be present"); - let mut main_frame = EvaluationFrame::new(trace.trace_info().main_segment_width()); - let mut aux_frame = EvaluationFrame::new(trace.trace_info().aux_segment_width()); - - let c = self.gkr_data.compute_batched_claim(); - let mean = c / E::from(E::BaseField::from(trace.trace_info().length() as u32)); - let mut query = vec![E::BaseField::ZERO; evaluator.get_oracles().len()]; - - for step in 0..domain.ce_domain_size() { - // compute Lagrange kernel frame - trace.read_lagrange_kernel_frame_into( - step << lde_shift, - l_col_idx, - &mut lagrange_frame, - ); - - // Compute the combined transition and boundary constraints evaluations for this row - let lagrange_combined_evaluations = { - let mut combined_evaluations = E::ZERO; - - // combine transition constraints - for trans_constraint_idx in - 0..self.lagrange_kernel_constraints.transition.num_constraints() - { - let numerator = - self.lagrange_kernel_constraints.transition.evaluate_ith_numerator( - &lagrange_frame, - &self.gkr_data.lagrange_kernel_eval_point, - trans_constraint_idx, - ); - let inv_divisor = trans_constraints_divisors - .get_inverse_divisor_eval(trans_constraint_idx, step); - - combined_evaluations += numerator * inv_divisor; - } - - // combine boundary constraints - { - let boundary_numerator = self - .lagrange_kernel_constraints - .boundary - .evaluate_numerator_at(&lagrange_frame); - - combined_evaluations += boundary_numerator * boundary_divisors_inv[step]; - } - - combined_evaluations - }; - - // compute and combine the transition constraints for the s-column. - // The s-column implements the cohomological sum-check argument of [1] and - // the constraint we enfore is exactly Eq (4) in Lemma 1 in [1]. - // - // [1]: https://eprint.iacr.org/2021/930 - let s_col_combined_evaluation = { - trace.read_main_trace_frame_into(step << lde_shift, &mut main_frame); - trace.read_aux_trace_frame_into(step << lde_shift, &mut aux_frame); - - let l_cur = aux_frame.current()[l_col_idx]; - let s_cur = aux_frame.current()[s_col_idx]; - let s_nxt = aux_frame.next()[s_col_idx]; - - evaluator.build_query(&main_frame, &mut query); - let batched_query = self.gkr_data.compute_batched_query(&query); - - let rhs = s_cur - mean + batched_query * l_cur; - let lhs = s_nxt; - - let divisor_at_step = - s_col_constraint_divisor[step % (domain.trace_to_ce_blowup())]; - - (rhs - lhs) * self.s_col_composition_coefficient.mul_base(divisor_at_step) - }; - - combined_evaluations_acc[step] += - lagrange_combined_evaluations + s_col_combined_evaluation; - } - } - - // HELPERS - // --------------------------------------------------------------------------------------------- - - /// Computes the inverse boundary divisor at every point of the constraint evaluation domain. - /// That is, returns a vector of the form `[1 / div_0, ..., 1 / div_n]`, where `div_i` is the - /// divisor for the Lagrange kernel boundary constraint at the i'th row of the constraint - /// evaluation domain. - fn compute_boundary_divisors_inv(&self, domain: &StarkDomain) -> Vec { - let mut boundary_denominator_evals = Vec::with_capacity(domain.ce_domain_size()); - for step in 0..domain.ce_domain_size() { - let domain_point = domain.get_ce_x_at(step); - let boundary_denominator = self - .lagrange_kernel_constraints - .boundary - .evaluate_denominator_at(domain_point.into()); - boundary_denominator_evals.push(boundary_denominator); - } - - batch_inversion(&boundary_denominator_evals) - } } -/// Holds all the transition constraint inverse divisor evaluations over the constraint evaluation -/// domain. +/// Holds all the transition and boundary constraint inverse divisor evaluations over +/// the constraint evaluation domain for both the Lagrange kernel as well the s-column. /// -/// [`LagrangeKernelTransitionConstraintsDivisor`] takes advantage of some structure in the -/// divisors' evaluations. Recall that the divisor for the i'th transition constraint is `x^(2^i) - -/// 1`. When substituting `x` for each value of the constraint evaluation domain, for constraints +/// [`LogUpGkrConstraintsDivisors`] takes advantage of some structure in the divisors' +/// evaluations for transition constraints. +/// Recall that the divisor for the i'th transition constraint is `x^(2^i) - 1`. +/// When substituting `x` for each value of the constraint evaluation domain, for constraints /// `i>0`, the divisor evaluations "wrap-around" such that some values repeat. For example, /// /// i=0: no repetitions @@ -192,8 +75,17 @@ where /// ... /// Therefore, we only compute the non-repeating section of the buffer in each iteration, and index /// into it accordingly. -struct LagrangeKernelTransitionConstraintsDivisor { - divisor_evals_inv: Vec, +/// +/// Note that instead of storing `1 / div` for Lagrange and s-column transition and boundary +/// constraints, we store instead `c / div` where `c` is the constraint composition coefficient +/// associated to divisor `div`. We call `c / div` constraint evaluation multipliers or just +/// constraint multipliers. +pub(crate) struct LogUpGkrConstraintsDivisors { + lagrange_transition_multipliers: Vec, + + lagrange_boundary_multipliers: Vec, + + s_col_transition_multipliers: Vec, // Precompute the indices into `divisors_evals_inv` of the slices that correspond to each // transition constraint. @@ -206,23 +98,50 @@ struct LagrangeKernelTransitionConstraintsDivisor { slice_indices_precomputes: Vec, } -impl LagrangeKernelTransitionConstraintsDivisor { +impl LogUpGkrConstraintsDivisors { pub fn new( - num_lagrange_transition_constraints: usize, + logup_gkr_constraints: &LogUpGkrConstraintsEvaluator, domain: &StarkDomain, ) -> Self { - let divisor_evals_inv = { + let num_lagrange_transition_constraints = + logup_gkr_constraints.lagrange_kernel_constraints.transition.num_constraints(); + + // collect all constraint composition coefficient in order to optimize inversion + let mut lagrange_transition_cc = logup_gkr_constraints + .lagrange_kernel_constraints + .transition + .lagrange_constraint_coefficients() + .to_vec(); + let lagrange_boundary_cc = logup_gkr_constraints + .lagrange_kernel_constraints + .boundary + .constraint_composition_coefficient(); + let s_col_cc = logup_gkr_constraints.s_col_composition_coefficient; + + lagrange_transition_cc.push(lagrange_boundary_cc); + lagrange_transition_cc.push(s_col_cc); + + // batch invert + let constraint_composition_coefficients = lagrange_transition_cc; + let constraint_composition_coefficients_inv = + batch_inversion(&constraint_composition_coefficients); + + let lagrange_cc_inv = + &constraint_composition_coefficients_inv[..num_lagrange_transition_constraints]; + let lagrange_transition_multipliers = { let divisor_evaluator = TransitionDivisorEvaluator::::new( num_lagrange_transition_constraints, domain.offset(), ); // The number of divisor evaluations is - // `ce_domain_size + ce_domain_size/2 + ce_domain_size/4 + ... + ce_domain_size/(log(ce_domain_size)-1)`, - // which is slightly smaller than `ce_domain_size * 2` - let mut divisor_evals: Vec = Vec::with_capacity(domain.ce_domain_size() * 2); + // `ce_domain_size + ce_domain_size/2 + ce_domain_size/4 + ... + + // ce_domain_size/(log(ce_domain_size)-1)`, + // which is slightly smaller than `ce_domain_size * 2`. + // This is also the number of multipliers `c / div` for Lagrange transition constraints + let mut multipliers: Vec = Vec::with_capacity(domain.ce_domain_size() * 2); - for trans_constraint_idx in 0..num_lagrange_transition_constraints { + for (trans_constraint_idx, cc_inv) in lagrange_cc_inv.iter().enumerate() { let num_non_repeating_denoms = domain.ce_domain_size() / 2_usize.pow(trans_constraint_idx as u32); @@ -230,13 +149,38 @@ impl LagrangeKernelTransitionConstraintsDivisor { let divisor_eval = divisor_evaluator.evaluate_ith_divisor(trans_constraint_idx, domain, step); - divisor_evals.push(divisor_eval.into()); + multipliers.push(cc_inv.mul_base(divisor_eval)); } } - batch_inversion(&divisor_evals) + batch_inversion(&multipliers) }; + // computes the inverse boundary divisor multiplier by the corresponding constraint + // composition at every point of the constraint evaluation domain. + // That is, returns a vector of the form `[c / div_0, ..., c / div_n]`, where `div_i` is the + // divisor for the Lagrange kernel boundary constraint against the first row at the i'th row + // of the constraint evaluation domain, and `c` is the constraint evaluation coefficient. + let lagrange_boundary_multipliers = { + let mut multipliers = Vec::with_capacity(domain.ce_domain_size()); + for step in 0..domain.ce_domain_size() { + let domain_point = domain.get_ce_x_at(step); + let boundary_denominator = domain_point - E::BaseField::ONE; + let multiplier = constraint_composition_coefficients_inv + [num_lagrange_transition_constraints] + .mul_base(boundary_denominator); + multipliers.push(multiplier); + } + + batch_inversion(&multipliers) + }; + + // compute the divisors for the s-column transition constraint + let s_col_transition_multipliers = compute_s_col_multipliers( + domain, + constraint_composition_coefficients_inv[num_lagrange_transition_constraints + 1], + ); + let slice_indices_precomputes = { let num_indices = num_lagrange_transition_constraints + 1; let mut slice_indices_precomputes = Vec::with_capacity(num_indices); @@ -255,30 +199,50 @@ impl LagrangeKernelTransitionConstraintsDivisor { }; Self { - divisor_evals_inv, + lagrange_transition_multipliers, + lagrange_boundary_multipliers, slice_indices_precomputes, + s_col_transition_multipliers, } } - /// Returns the evaluation `1 / divisor`, where `divisor` is the divisor for the given - /// transition constraint, at the given row of the constraint evaluation domain - pub fn get_inverse_divisor_eval(&self, trans_constraint_idx: usize, row_idx: usize) -> E { - let inv_divisors_slice_for_constraint = - self.get_transition_constraint_slice(trans_constraint_idx); + /// Returns the evaluation `c / divisor`, where `divisor` is the divisor for the given + /// Lagrange kernel transition constraint, at the given row of the constraint evaluation domain + /// and `c` is the corresponding constraint composition coefficient. + pub fn get_lagrange_transition_multiplier( + &self, + trans_constraint_idx: usize, + row_idx: usize, + ) -> E { + let multipliers_slice = self.get_lagrange_transition_constraint_slice(trans_constraint_idx); + + multipliers_slice[row_idx % multipliers_slice.len()] + } + + /// Returns the evaluation `c / divisor`, where `divisor` runs over all Lagrange kernel + /// boundary constraint divisors at the given row of the constraint evaluation domain and `c` + /// is the corresponding constraint composition coefficient. + pub fn get_lagrange_boundary_multiplier(&self, row_idx: usize) -> E { + self.lagrange_boundary_multipliers[row_idx % self.lagrange_boundary_multipliers.len()] + } - inv_divisors_slice_for_constraint[row_idx % inv_divisors_slice_for_constraint.len()] + /// Returns the evaluation `c / divisor`, where `divisor` is the divisor for the s-column + /// transition constraint, at the given row of the constraint evaluation domain and `c` is + /// the corresponding constraint composition coefficient. + pub fn get_s_col_transition_multiplier(&self, row_idx: usize) -> E { + self.s_col_transition_multipliers[row_idx % (self.s_col_transition_multipliers.len())] } // HELPERS // --------------------------------------------------------------------------------------------- - /// Returns a slice containing all the inverse divisor evaluations for the given transition - /// constraint. - fn get_transition_constraint_slice(&self, trans_constraint_idx: usize) -> &[E] { + /// Returns a slice containing all the multipliers evaluations' for the given Lagrange + /// transition constraint. + fn get_lagrange_transition_constraint_slice(&self, trans_constraint_idx: usize) -> &[E] { let start = self.slice_indices_precomputes[trans_constraint_idx]; let end = self.slice_indices_precomputes[trans_constraint_idx + 1]; - &self.divisor_evals_inv[start..end] + &self.lagrange_transition_multipliers[start..end] } } @@ -342,21 +306,21 @@ impl TransitionDivisorEvaluator { } } -/// Computes the evaluations of the s-column divisor. +/// Computes the evaluations of the s-column multipliers. /// /// The divisor for the s-column is $X^n - 1$ where $n$ is the trace length. This means that /// we need only compute `ce_blowup` many values and thus only that many exponentiations. -fn compute_s_col_divisor( +fn compute_s_col_multipliers( domain: &StarkDomain, - trace_length: usize, -) -> Vec { - let degree = trace_length as u32; + composition_coef_inv: E, +) -> Vec { + let degree = domain.trace_length() as u32; let mut result = Vec::with_capacity(domain.trace_to_ce_blowup()); for row in 0..domain.trace_to_ce_blowup() { - let x = domain.get_ce_x_at(row).exp(degree.into()) - E::BaseField::ONE; + let divisor = domain.get_ce_x_at(row).exp(degree.into()) - E::BaseField::ONE; - result.push(x); + result.push(composition_coef_inv.mul_base(divisor)); } batch_inversion(&result) } diff --git a/prover/src/constraints/evaluator/logup_gkr_evaluator.rs b/prover/src/constraints/evaluator/logup_gkr_evaluator.rs new file mode 100644 index 000000000..2c26e735b --- /dev/null +++ b/prover/src/constraints/evaluator/logup_gkr_evaluator.rs @@ -0,0 +1,433 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use air::{ + Air, AuxRandElements, ConstraintCompositionCoefficients, EvaluationFrame, + LagrangeKernelEvaluationFrame, LogUpGkrEvaluator, TransitionConstraints, +}; +use math::FieldElement; +use tracing::instrument; +use utils::iter_mut; +#[cfg(feature = "concurrent")] +use utils::{iterators::*, rayon}; + +use super::{ + super::EvaluationTableFragment, + logup_gkr::{LogUpGkrConstraintsDivisors, LogUpGkrConstraintsEvaluator}, + BoundaryConstraints, CompositionPolyTrace, ConstraintEvaluationTable, ConstraintEvaluator, + PeriodicValueTable, StarkDomain, TraceLde, +}; + +// CONSTANTS +// ================================================================================================ + +#[cfg(feature = "concurrent")] +const MIN_CONCURRENT_DOMAIN_SIZE: usize = 8192; + +// DEFAULT CONSTRAINT EVALUATOR +// ================================================================================================ + +/// Default implementation of the [ConstraintEvaluator] trait. +/// +/// This implementation iterates over all evaluation frames of an extended execution trace and +/// evaluates constraints over these frames one-by-one. Constraint evaluations are merged together +/// using random linear combinations and in the end, only a single column is returned. +/// +/// When `concurrent` feature is enabled, the extended execution trace is split into sets of +/// sequential evaluation frames (called fragments), and frames in each fragment are evaluated +/// in separate threads. +pub struct LogUpGkrConstraintEvaluator<'a, A: Air, E: FieldElement> { + air: &'a A, + boundary_constraints: BoundaryConstraints, + transition_constraints: TransitionConstraints, + periodic_values: PeriodicValueTable, + logup_gkr_constraints_evaluator: LogUpGkrConstraintsEvaluator, + aux_rand_elements: AuxRandElements, +} + +impl<'a, A, E> ConstraintEvaluator for LogUpGkrConstraintEvaluator<'a, A, E> +where + A: Air, + E: FieldElement, +{ + type Air = A; + + #[instrument( + skip_all, + name = "evaluate_constraints", + fields( + ce_domain_size = %domain.ce_domain_size() + ) + )] + fn evaluate>( + self, + trace: &T, + domain: &StarkDomain<::BaseField>, + ) -> CompositionPolyTrace { + assert_eq!( + trace.trace_len(), + domain.lde_domain_size(), + "extended trace length is not consistent with evaluation domain" + ); + + // build a list of constraint divisors; currently, all transition constraints have the same + // divisor which we put at the front of the list; boundary constraint divisors are appended + // after that + let mut divisors = vec![self.transition_constraints.divisor().clone()]; + divisors.append(&mut self.boundary_constraints.get_divisors()); + + // build the divisors related to LogUp-GKR + let logup_gkr_constraints_divisors = + LogUpGkrConstraintsDivisors::::new(&self.logup_gkr_constraints_evaluator, domain); + + // allocate space for constraint evaluations; when we are in debug mode, we also allocate + // memory to hold all transition constraint evaluations (before they are merged into a + // single value) so that we can check their degrees later + #[cfg(not(debug_assertions))] + let mut evaluation_table = ConstraintEvaluationTable::::new(domain, divisors, true); + #[cfg(debug_assertions)] + let mut evaluation_table = ConstraintEvaluationTable::::new( + domain, + divisors, + &self.transition_constraints, + true, + ); + + // when `concurrent` feature is enabled, break the evaluation table into multiple fragments + // to evaluate them into multiple threads; unless the constraint evaluation domain is small, + // then don't bother with concurrent evaluation + + #[cfg(not(feature = "concurrent"))] + let num_fragments = 1; + + #[cfg(feature = "concurrent")] + let num_fragments = if domain.ce_domain_size() >= MIN_CONCURRENT_DOMAIN_SIZE { + rayon::current_num_threads().next_power_of_two() + } else { + 1 + }; + + // evaluate constraints for each fragment; if the trace consist of multiple segments + // we evaluate constraints for all segments. otherwise, we evaluate constraints only + // for the main segment. + let mut fragments = evaluation_table.fragments(num_fragments); + iter_mut!(fragments).for_each(|fragment| { + self.evaluate_fragment_full(trace, domain, fragment, &logup_gkr_constraints_divisors); + }); + + // when in debug mode, make sure expected transition constraint degrees align with + // actual degrees we got during constraint evaluation + #[cfg(debug_assertions)] + evaluation_table.validate_transition_degrees(); + + CompositionPolyTrace::new(evaluation_table.combine()) + } +} + +impl<'a, A, E> LogUpGkrConstraintEvaluator<'a, A, E> +where + A: Air, + E: FieldElement, +{ + // CONSTRUCTOR + // -------------------------------------------------------------------------------------------- + /// Returns a new evaluator which can be used to evaluate transition and boundary constraints + /// over extended execution trace. + pub fn new( + air: &'a A, + aux_rand_elements: AuxRandElements, + composition_coefficients: ConstraintCompositionCoefficients, + ) -> Self { + assert!( + air.context().logup_gkr_enabled(), + "`LogUpGkrConstraintEvaluator` can only be used when LogUp-GKR is enabled" + ); + + // build transition constraint groups; these will be used to compose transition constraint + // evaluations + let transition_constraints = + air.get_transition_constraints(&composition_coefficients.transition); + // build periodic value table + let periodic_values = PeriodicValueTable::new(air); + + // build boundary constraint groups; these will be used to evaluate and compose boundary + // constraint evaluations. + let boundary_constraints = BoundaryConstraints::new( + air, + Some(&aux_rand_elements), + &composition_coefficients.boundary, + ); + + let logup_gkr_constraints_evaluator = LogUpGkrConstraintsEvaluator::new( + air, + aux_rand_elements + .gkr_data() + .expect("expected LogUp-GKR randomness to be present"), + composition_coefficients + .lagrange + .expect("expected Lagrange kernel composition coefficients to be present"), + composition_coefficients + .s_col + .expect("expected s-column composition coefficient to be present"), + ); + air.trace_info(); + + Self { + air, + boundary_constraints, + transition_constraints, + logup_gkr_constraints_evaluator, + aux_rand_elements, + periodic_values, + } + } + + // EVALUATION HELPER + // -------------------------------------------------------------------------------------------- + + /// Evaluates constraints for a single fragment of the evaluation table. + /// + /// This evaluates constraints only over all segments of the execution trace (i.e. main segment + /// and all auxiliary segments). + fn evaluate_fragment_full>( + &self, + trace: &T, + domain: &StarkDomain, + fragment: &mut EvaluationTableFragment, + logup_gkr_divisors: &LogUpGkrConstraintsDivisors, + ) { + // initialize buffers to hold trace values and evaluation results at each step + let mut main_frame = EvaluationFrame::new(trace.trace_info().main_segment_width()); + let mut aux_frame = EvaluationFrame::new(trace.trace_info().aux_segment_width()); + let mut tm_evaluations = vec![E::BaseField::ZERO; self.num_main_transition_constraints()]; + let mut ta_evaluations = vec![E::ZERO; self.num_aux_transition_constraints()]; + let mut evaluations = vec![E::ZERO; fragment.num_columns()]; + let mut lagrange_frame = LagrangeKernelEvaluationFrame::new(trace.trace_info().length()); + + let evaluator = self.air.get_logup_gkr_evaluator(); + let mut query = vec![E::BaseField::ZERO; evaluator.get_oracles().len()]; + + // this will be used to convert steps in constraint evaluation domain to steps in + // LDE domain + let lde_shift = domain.ce_to_lde_blowup().trailing_zeros(); + + for i in 0..fragment.num_rows() { + let step = i + fragment.offset(); + + // read both the main and the auxiliary evaluation frames from the trace + trace.read_main_trace_frame_into(step << lde_shift, &mut main_frame); + trace.read_lagrange_kernel_frame_into( + step << lde_shift, + self.logup_gkr_constraints_evaluator.l_col_idx, + &mut lagrange_frame, + ); + trace.read_aux_trace_frame_into(step << lde_shift, &mut aux_frame); + + // evaluate transition constraints and save the merged result the first slot of the + // evaluations buffer; we evaluate and compose constraints in the same function, we + // can just add up the results of evaluating main and auxiliary constraints. + evaluations[0] = self.evaluate_main_transition(&main_frame, step, &mut tm_evaluations); + evaluations[0] += + self.evaluate_aux_transition(&main_frame, &aux_frame, step, &mut ta_evaluations); + + // when in debug mode, save transition constraint evaluations + #[cfg(debug_assertions)] + fragment.update_transition_evaluations(i, &tm_evaluations, &ta_evaluations); + + // evaluate Lagrange kernel constraints and assign them to the last column + *evaluations.last_mut().expect("should contain at least one entry") = self + .evaluate_s_column_transition( + &evaluator, + &main_frame, + &aux_frame, + &mut query, + logup_gkr_divisors.get_s_col_transition_multiplier(step), + ); + // evaluate s-column constraints and add them to the last column + *evaluations.last_mut().expect("should contain at least one entry") += + self.evaluate_lagrange_transition(&lagrange_frame, step, logup_gkr_divisors); + + // evaluate boundary constraints; the results go into remaining slots of the + // evaluations buffer + let main_state = main_frame.current(); + let aux_state = aux_frame.current(); + let limit = evaluations.len() - 1; + self.boundary_constraints.evaluate_all( + main_state, + aux_state, + domain, + step, + &mut evaluations[1..limit], + ); + + // record the result in the evaluation table + fragment.update_row(i, &evaluations); + } + } + + // TRANSITION CONSTRAINT EVALUATOR + // -------------------------------------------------------------------------------------------- + + /// Evaluates transition constraints of the main execution trace at the specified step of the + /// constraint evaluation domain. + /// + /// `x` is the corresponding domain value at the specified step. That is, x = s * g^step, + /// where g is the generator of the constraint evaluation domain, and s is the domain offset. + fn evaluate_main_transition( + &self, + main_frame: &EvaluationFrame, + step: usize, + evaluations: &mut [E::BaseField], + ) -> E { + // TODO: use a more efficient way to zero out memory + evaluations.fill(E::BaseField::ZERO); + + // get periodic values at the evaluation step + let periodic_values = self.periodic_values.get_row(step); + + // evaluate transition constraints over the main segment of the execution trace and save + // the results into evaluations buffer + self.air.evaluate_transition(main_frame, periodic_values, evaluations); + + // merge transition constraint evaluations into a single value and return it; + // we can do this here because all transition constraints have the same divisor. + evaluations + .iter() + .zip(self.transition_constraints.main_constraint_coef().iter()) + .fold(E::ZERO, |acc, (&const_eval, &coef)| acc + coef.mul_base(const_eval)) + } + + /// Evaluates all transition constraints (i.e., for main and the auxiliary trace segment) at the + /// specified step of the constraint evaluation domain. + /// + /// `x` is the corresponding domain value at the specified step. That is, x = s * g^step, + /// where g is the generator of the constraint evaluation domain, and s is the domain offset. + fn evaluate_aux_transition( + &self, + main_frame: &EvaluationFrame, + aux_frame: &EvaluationFrame, + step: usize, + evaluations: &mut [E], + ) -> E { + // TODO: use a more efficient way to zero out memory + evaluations.fill(E::ZERO); + + // get periodic values at the evaluation step + let periodic_values = self.periodic_values.get_row(step); + + // evaluate transition constraints over the auxiliary trace segment and save the results into + // evaluations buffer + self.air.evaluate_aux_transition( + main_frame, + aux_frame, + periodic_values, + &self.aux_rand_elements, + evaluations, + ); + + // merge transition constraint evaluations into a single value and return it; + // we can do this here because all transition constraints have the same divisor. + let evaluation = evaluations + .iter() + .zip(self.transition_constraints.aux_constraint_coef().iter()) + .fold(E::ZERO, |acc, (&const_eval, &coef)| acc + coef * const_eval); + + evaluation + } + + /// Computes the transition and boundary constraints for the Lagrange kernel. + fn evaluate_lagrange_transition( + &self, + lagrange_frame: &LagrangeKernelEvaluationFrame, + step: usize, + constraints_divisors: &LogUpGkrConstraintsDivisors, + ) -> E { + // Compute the combined transition and boundary constraints evaluations for this row + + let mut combined_evaluations = E::ZERO; + + // combine transition constraints + for trans_constraint_idx in 0..self + .logup_gkr_constraints_evaluator + .lagrange_kernel_constraints + .transition + .num_constraints() + { + let numerator = self + .logup_gkr_constraints_evaluator + .lagrange_kernel_constraints + .transition + .evaluate_ith_numerator( + lagrange_frame, + &self.logup_gkr_constraints_evaluator.gkr_data.lagrange_kernel_eval_point, + trans_constraint_idx, + ); + let multiplier = + constraints_divisors.get_lagrange_transition_multiplier(trans_constraint_idx, step); + + combined_evaluations += numerator * multiplier; + } + + // combine boundary constraints + { + let boundary_numerator = self + .logup_gkr_constraints_evaluator + .lagrange_kernel_constraints + .boundary + .evaluate_numerator_at(lagrange_frame); + + combined_evaluations += + boundary_numerator * constraints_divisors.get_lagrange_boundary_multiplier(step); + } + + combined_evaluations + } + + /// Computes the transition constraints for the s-column. + /// + /// The s-column implements the cohomological sum-check argument of [1] and + /// the constraint we enfore is exactly Eq (4) in Lemma 1 in [1]. + /// + /// [1]: https://eprint.iacr.org/2021/930 + fn evaluate_s_column_transition( + &self, + evaluator: &impl LogUpGkrEvaluator, + main_frame: &EvaluationFrame, + aux_frame: &EvaluationFrame, + query: &mut [E::BaseField], + multiplier: E, + ) -> E { + let l_col_idx = self.logup_gkr_constraints_evaluator.l_col_idx; + let s_col_idx = self.logup_gkr_constraints_evaluator.s_col_idx; + let mean = self.logup_gkr_constraints_evaluator.mean; + + let l_cur = aux_frame.current()[l_col_idx]; + let s_cur = aux_frame.current()[s_col_idx]; + let s_nxt = aux_frame.next()[s_col_idx]; + + evaluator.build_query(main_frame, query); + let batched_query = + self.logup_gkr_constraints_evaluator.gkr_data.compute_batched_query(query); + + let rhs = s_cur - mean + batched_query * l_cur; + let lhs = s_nxt; + + (rhs - lhs) * multiplier + } + + // ACCESSORS + // -------------------------------------------------------------------------------------------- + + /// Returns the number of transition constraints applied against the main segment of the + /// execution trace. + fn num_main_transition_constraints(&self) -> usize { + self.transition_constraints.num_main_constraints() + } + + /// Returns the number of transition constraints applied against the auxiliary trace segment. + fn num_aux_transition_constraints(&self) -> usize { + self.transition_constraints.num_aux_constraints() + } +} diff --git a/prover/src/constraints/evaluator/mod.rs b/prover/src/constraints/evaluator/mod.rs index 0ff6916f8..ce0488a57 100644 --- a/prover/src/constraints/evaluator/mod.rs +++ b/prover/src/constraints/evaluator/mod.rs @@ -16,6 +16,9 @@ use boundary::BoundaryConstraints; mod logup_gkr; +mod logup_gkr_evaluator; +pub use logup_gkr_evaluator::LogUpGkrConstraintEvaluator; + mod periodic_table; use periodic_table::PeriodicValueTable; diff --git a/prover/src/constraints/mod.rs b/prover/src/constraints/mod.rs index 566065f0f..9cf84c3dd 100644 --- a/prover/src/constraints/mod.rs +++ b/prover/src/constraints/mod.rs @@ -6,7 +6,7 @@ use super::{ColMatrix, ConstraintDivisor, RowMatrix, StarkDomain}; mod evaluator; -pub use evaluator::{ConstraintEvaluator, DefaultConstraintEvaluator}; +pub use evaluator::{ConstraintEvaluator, DefaultConstraintEvaluator, LogUpGkrConstraintEvaluator}; mod composition_poly; pub use composition_poly::{CompositionPoly, CompositionPolyTrace}; diff --git a/prover/src/lib.rs b/prover/src/lib.rs index b62df14d8..e15028171 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -74,7 +74,7 @@ use matrix::{ColMatrix, RowMatrix}; mod constraints; pub use constraints::{ CompositionPoly, CompositionPolyTrace, ConstraintCommitment, ConstraintEvaluator, - DefaultConstraintEvaluator, + DefaultConstraintEvaluator, LogUpGkrConstraintEvaluator, }; mod composer; diff --git a/prover/src/trace/mod.rs b/prover/src/trace/mod.rs index 2b2e89a9e..08fb49a2a 100644 --- a/prover/src/trace/mod.rs +++ b/prover/src/trace/mod.rs @@ -56,7 +56,7 @@ pub struct AuxTraceWithMetadata { /// implementation supports concurrent trace generation and should be sufficient in most /// situations. However, if functionality provided by [TraceTable] is not sufficient, uses can /// provide custom implementations of the [Trace] trait which better suit their needs. -pub trait Trace: Sized { +pub trait Trace: Sized + Sync { /// Base field for this execution trace. /// /// All cells of this execution trace contain values which are elements in this field. diff --git a/prover/src/trace/trace_lde/default/mod.rs b/prover/src/trace/trace_lde/default/mod.rs index e06839d53..2cb177bc5 100644 --- a/prover/src/trace/trace_lde/default/mod.rs +++ b/prover/src/trace/trace_lde/default/mod.rs @@ -195,20 +195,17 @@ where lagrange_kernel_aux_column_idx: usize, frame: &mut LagrangeKernelEvaluationFrame, ) { - let frame = frame.frame_mut(); - frame.truncate(0); - let aux_segment = self.aux_segment_lde.as_ref().expect("expected aux segment to be present"); - frame.push(aux_segment.get(lagrange_kernel_aux_column_idx, lde_step)); + frame[0] = aux_segment.get(lagrange_kernel_aux_column_idx, lde_step); let frame_length = self.trace_info.length().ilog2() as usize + 1; for i in 0..frame_length - 1 { let shift = self.blowup() * (1 << i); let next_lde_step = (lde_step + shift) % self.trace_len(); - frame.push(aux_segment.get(lagrange_kernel_aux_column_idx, next_lde_step)); + frame[i + 1] = aux_segment.get(lagrange_kernel_aux_column_idx, next_lde_step); } } diff --git a/winterfell/src/lib.rs b/winterfell/src/lib.rs index c3da2fb6a..621796864 100644 --- a/winterfell/src/lib.rs +++ b/winterfell/src/lib.rs @@ -596,8 +596,9 @@ pub use prover::{ BoundaryConstraint, BoundaryConstraintGroup, CompositionPolyTrace, ConstraintCompositionCoefficients, ConstraintDivisor, ConstraintEvaluator, DeepCompositionCoefficients, DefaultConstraintEvaluator, DefaultTraceLde, EvaluationFrame, - FieldExtension, Proof, ProofOptions, Prover, ProverError, StarkDomain, Trace, TraceInfo, - TraceLde, TracePolyTable, TraceTable, TraceTableFragment, TransitionConstraintDegree, + FieldExtension, LogUpGkrConstraintEvaluator, Proof, ProofOptions, Prover, ProverError, + StarkDomain, Trace, TraceInfo, TraceLde, TracePolyTable, TraceTable, TraceTableFragment, + TransitionConstraintDegree, }; pub use verifier::{verify, AcceptableOptions, ByteWriter, VerifierError}; diff --git a/winterfell/src/tests/logup_gkr_periodic.rs b/winterfell/src/tests/logup_gkr_periodic.rs index d6ae0c530..5a5b369c1 100644 --- a/winterfell/src/tests/logup_gkr_periodic.rs +++ b/winterfell/src/tests/logup_gkr_periodic.rs @@ -17,7 +17,7 @@ use crate::{ crypto::{hashers::Blake3_256, DefaultRandomCoin}, math::{fields::f64::BaseElement, ExtensionOf, FieldElement}, matrix::ColMatrix, - DefaultConstraintEvaluator, DefaultTraceLde, Prover, StarkDomain, TracePolyTable, + DefaultTraceLde, Prover, StarkDomain, TracePolyTable, }; #[test] @@ -299,7 +299,7 @@ impl Prover for LogUpGkrPeriodicProver { type TraceLde> = DefaultTraceLde; type ConstraintEvaluator<'a, E: FieldElement> = - DefaultConstraintEvaluator<'a, LogUpGkrPeriodicAir, E>; + LogUpGkrConstraintEvaluator<'a, LogUpGkrPeriodicAir, E>; fn get_pub_inputs(&self, _trace: &Self::Trace) -> <::Air as Air>::PublicInputs { } @@ -329,7 +329,7 @@ impl Prover for LogUpGkrPeriodicProver { where E: math::FieldElement, { - DefaultConstraintEvaluator::new(air, aux_rand_elements, composition_coefficients) + LogUpGkrConstraintEvaluator::new(air, aux_rand_elements.unwrap(), composition_coefficients) } fn build_aux_trace(&self, main_trace: &Self::Trace, _aux_rand_elements: &[E]) -> ColMatrix diff --git a/winterfell/src/tests/logup_gkr_simple.rs b/winterfell/src/tests/logup_gkr_simple.rs index 3ffe8ea3b..406f6eac1 100644 --- a/winterfell/src/tests/logup_gkr_simple.rs +++ b/winterfell/src/tests/logup_gkr_simple.rs @@ -17,7 +17,7 @@ use crate::{ crypto::{hashers::Blake3_256, DefaultRandomCoin}, math::{fields::f64::BaseElement, ExtensionOf, FieldElement}, matrix::ColMatrix, - DefaultConstraintEvaluator, DefaultTraceLde, Prover, StarkDomain, TracePolyTable, + DefaultTraceLde, Prover, StarkDomain, TracePolyTable, }; #[test] @@ -285,7 +285,7 @@ impl Prover for LogUpGkrSimpleProver { type TraceLde> = DefaultTraceLde; type ConstraintEvaluator<'a, E: FieldElement> = - DefaultConstraintEvaluator<'a, LogUpGkrSimpleAir, E>; + LogUpGkrConstraintEvaluator<'a, LogUpGkrSimpleAir, E>; fn get_pub_inputs(&self, _trace: &Self::Trace) -> <::Air as Air>::PublicInputs { } @@ -315,7 +315,7 @@ impl Prover for LogUpGkrSimpleProver { where E: math::FieldElement, { - DefaultConstraintEvaluator::new(air, aux_rand_elements, composition_coefficients) + LogUpGkrConstraintEvaluator::new(air, aux_rand_elements.unwrap(), composition_coefficients) } fn build_aux_trace(&self, main_trace: &Self::Trace, _aux_rand_elements: &[E]) -> ColMatrix From 8b95c8a10f05686db785e67ed0d3b467611e7066 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philippe=20Laferri=C3=A8re?= Date: Wed, 18 Sep 2024 02:33:23 -0400 Subject: [PATCH 11/19] Optimize `bind_least_significant_variable` (#319) --- sumcheck/benches/bind_variable.rs | 3 +-- sumcheck/src/multilinear.rs | 24 +++++++++++++++++++----- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/sumcheck/benches/bind_variable.rs b/sumcheck/benches/bind_variable.rs index 07bbcc18d..4e65f684b 100644 --- a/sumcheck/benches/bind_variable.rs +++ b/sumcheck/benches/bind_variable.rs @@ -16,8 +16,7 @@ const POLY_SIZE: [usize; 2] = [1 << 18, 1 << 20]; fn bind_variable(c: &mut Criterion) { let mut group = c.benchmark_group("bind variable "); - group.sample_size(10); - group.measurement_time(Duration::from_secs(10)); + group.measurement_time(Duration::from_secs(15)); for &poly_size in POLY_SIZE.iter() { group.bench_function(BenchmarkId::new("", poly_size), |b| { diff --git a/sumcheck/src/multilinear.rs b/sumcheck/src/multilinear.rs index 110ef1fa7..df6177914 100644 --- a/sumcheck/src/multilinear.rs +++ b/sumcheck/src/multilinear.rs @@ -70,23 +70,37 @@ impl MultiLinearPoly { /// Computes $f(r_0, y_1, ..., y_{{\nu} - 1})$ using the linear interpolation formula /// $(1 - r_0) * f(0, y_1, ..., y_{{\nu} - 1}) + r_0 * f(1, y_1, ..., y_{{\nu} - 1})$ and assigns /// the resulting multi-linear, defined over a domain of half the size, to `self`. + #[inline(always)] pub fn bind_least_significant_variable(&mut self, round_challenge: E) { let num_evals = self.evaluations.len() >> 1; #[cfg(not(feature = "concurrent"))] { for i in 0..num_evals { - self.evaluations[i] = self.evaluations[i << 1] - + round_challenge * (self.evaluations[(i << 1) + 1] - self.evaluations[i << 1]); + // SAFETY: This loops over [0, evaluations.len()/2). The largest value for `i` is + // `(evaluations.len() / 2) - 1`. Hence, the largest value for `(i<<1)` is + // `evaluations.len() - 2`, and largest value for `(i<<1) + 1` is `evaluations.len() - 1`. + let evaluations_2i = unsafe { *self.evaluations.get_unchecked(i << 1) }; + let evaluations_2i_plus_1 = + unsafe { *self.evaluations.get_unchecked((i << 1) + 1) }; + + self.evaluations[i] = + evaluations_2i + round_challenge * (evaluations_2i_plus_1 - evaluations_2i); } - self.evaluations.truncate(num_evals) + self.evaluations.truncate(num_evals); } #[cfg(feature = "concurrent")] { let mut result = unsafe { utils::uninit_vector(num_evals) }; result.par_iter_mut().enumerate().for_each(|(i, ev)| { - *ev = self.evaluations[i << 1] - + round_challenge * (self.evaluations[(i << 1) + 1] - self.evaluations[i << 1]) + // SAFETY: This loops over [0, evaluations.len()/2). The largest value for `i` is + // `(evaluations.len() / 2) - 1`. Hence, the largest value for `(i<<1)` is + // `evaluations.len() - 2`, and largest value for `(i<<1) + 1` is `evaluations.len() - 1`. + let evaluations_2i = unsafe { *self.evaluations.get_unchecked(i << 1) }; + let evaluations_2i_plus_1 = + unsafe { *self.evaluations.get_unchecked((i << 1) + 1) }; + + *ev = evaluations_2i + round_challenge * (evaluations_2i_plus_1 - evaluations_2i); }); self.evaluations = result } From 8f08bd03581232277ed8c7830c9dd108ad12532e Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Wed, 18 Sep 2024 19:44:37 +0200 Subject: [PATCH 12/19] Add instrumentation (#321) --- prover/src/logup_gkr/mod.rs | 2 ++ prover/src/logup_gkr/prover.rs | 5 +++++ 2 files changed, 7 insertions(+) diff --git a/prover/src/logup_gkr/mod.rs b/prover/src/logup_gkr/mod.rs index 2c4846369..474d8fd7e 100644 --- a/prover/src/logup_gkr/mod.rs +++ b/prover/src/logup_gkr/mod.rs @@ -4,6 +4,7 @@ use core::ops::Add; use air::{EvaluationFrame, GkrData, LogUpGkrEvaluator}; use math::FieldElement; use sumcheck::{EqFunction, MultiLinearPoly, SumCheckProverError}; +use tracing::instrument; use utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; use crate::Trace; @@ -56,6 +57,7 @@ pub struct EvaluatedCircuit { impl EvaluatedCircuit { /// Creates a new [`EvaluatedCircuit`] by evaluating the circuit where the input layer is /// defined from the main trace columns. + #[instrument(skip_all, name = "evaluate_logup_gkr_circuit")] pub fn new( main_trace_columns: &impl Trace, evaluator: &impl LogUpGkrEvaluator, diff --git a/prover/src/logup_gkr/prover.rs b/prover/src/logup_gkr/prover.rs index f1a66cf35..111ac374f 100644 --- a/prover/src/logup_gkr/prover.rs +++ b/prover/src/logup_gkr/prover.rs @@ -7,6 +7,7 @@ use sumcheck::{ sum_check_prove_higher_degree, sumcheck_prove_plain, BeforeFinalLayerProof, CircuitOutput, EqFunction, FinalLayerProof, GkrCircuitProof, MultiLinearPoly, SumCheckProof, }; +use tracing::instrument; use super::{reduce_layer_claim, CircuitLayerPolys, EvaluatedCircuit, GkrClaim, GkrProverError}; use crate::{matrix::ColMatrix, Trace}; @@ -53,6 +54,7 @@ use crate::{matrix::ColMatrix, Trace}; /// As part of the final sum-check protocol, the openings {f_j(ρ)} are provided as part of a /// [`FinalOpeningClaim`]. This latter claim will be proven by the STARK prover later on using the /// auxiliary trace. +#[instrument(skip_all)] pub fn prove_gkr( main_trace: &impl Trace, evaluator: &impl LogUpGkrEvaluator, @@ -98,6 +100,7 @@ pub fn prove_gkr( } /// Proves the final GKR layer which corresponds to the input circuit layer. +#[instrument(skip_all)] fn prove_input_layer< E: FieldElement, C: RandomCoin, @@ -133,6 +136,7 @@ fn prove_input_layer< /// Builds the multi-linear extension polynomials needed to run the final sum-check of GKR for /// LogUp-GKR. +#[instrument(skip_all)] fn build_mls_from_main_trace_segment( oracles: &[LogUpGkrOracle], main_trace: &ColMatrix<::BaseField>, @@ -163,6 +167,7 @@ fn build_mls_from_main_trace_segment( } /// Proves all GKR layers except for input layer. +#[instrument(skip_all)] fn prove_intermediate_layers< E: FieldElement, C: RandomCoin, From ccc8819fdab27fab3e72299f655ec65a0daa4c9e Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Wed, 18 Sep 2024 20:12:39 +0200 Subject: [PATCH 13/19] fix: add instrumentation to aux segment (#322) --- prover/src/lib.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/prover/src/lib.rs b/prover/src/lib.rs index e15028171..9bcc566b6 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -205,6 +205,7 @@ pub trait Prover { /// Builds and returns the auxiliary trace. #[allow(unused_variables)] #[maybe_async] + #[instrument(skip_all)] fn build_aux_trace(&self, main_trace: &Self::Trace, aux_rand_elements: &[E]) -> ColMatrix where E: FieldElement, @@ -616,6 +617,7 @@ pub trait Prover { /// /// [1]: https://eprint.iacr.org/2023/1284 #[maybe_async] +#[instrument(skip_all)] fn build_logup_gkr_columns( air: &A, main_trace: &T, From 3d095e56f8559dd7a57c3d60c7319801d88a4429 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philippe=20Laferri=C3=A8re?= Date: Fri, 20 Sep 2024 12:37:08 -0400 Subject: [PATCH 14/19] fix: fixes multilinear built from a `NextRow` oracle (#327) --- prover/src/logup_gkr/prover.rs | 3 --- winterfell/src/tests/logup_gkr_simple.rs | 32 +++++++++++++++++++++--- 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/prover/src/logup_gkr/prover.rs b/prover/src/logup_gkr/prover.rs index 111ac374f..f258d0845 100644 --- a/prover/src/logup_gkr/prover.rs +++ b/prover/src/logup_gkr/prover.rs @@ -154,9 +154,6 @@ fn build_mls_from_main_trace_segment( LogUpGkrOracle::NextRow(index) => { let col = main_trace.get_column(*index); let mut values: Vec = col.iter().map(|value| E::from(*value)).collect(); - if let Some(value) = values.last_mut() { - *value = E::ZERO - } values.rotate_left(1); let ml = MultiLinearPoly::from_evaluations(values); mls.push(ml) diff --git a/winterfell/src/tests/logup_gkr_simple.rs b/winterfell/src/tests/logup_gkr_simple.rs index 406f6eac1..6c814c948 100644 --- a/winterfell/src/tests/logup_gkr_simple.rs +++ b/winterfell/src/tests/logup_gkr_simple.rs @@ -193,7 +193,23 @@ impl PlainLogUpGkrEval { let committed_2 = LogUpGkrOracle::CurrentRow(2); let committed_3 = LogUpGkrOracle::CurrentRow(3); let committed_4 = LogUpGkrOracle::CurrentRow(4); - let oracles = vec![committed_0, committed_1, committed_2, committed_3, committed_4]; + let committed_0_next_row = LogUpGkrOracle::NextRow(0); + let committed_1_next_row = LogUpGkrOracle::NextRow(1); + let committed_2_next_row = LogUpGkrOracle::NextRow(2); + let committed_3_next_row = LogUpGkrOracle::NextRow(3); + let committed_4_next_row = LogUpGkrOracle::NextRow(4); + let oracles = vec![ + committed_0, + committed_1, + committed_2, + committed_3, + committed_4, + committed_0_next_row, + committed_1_next_row, + committed_2_next_row, + committed_3_next_row, + committed_4_next_row, + ]; Self { oracles, _field: PhantomData } } } @@ -223,7 +239,17 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { where E: FieldElement, { - query.iter_mut().zip(frame.current().iter()).for_each(|(q, f)| *q = *f) + query[0] = frame.current()[0]; + query[1] = frame.current()[1]; + query[2] = frame.current()[2]; + query[3] = frame.current()[3]; + query[4] = frame.current()[4]; + + query[5] = frame.next()[0]; + query[6] = frame.next()[1]; + query[7] = frame.next()[2]; + query[8] = frame.next()[3]; + query[9] = frame.next()[4]; } fn evaluate_query( @@ -239,7 +265,7 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { { assert_eq!(numerator.len(), 4); assert_eq!(denominator.len(), 4); - assert_eq!(query.len(), 5); + assert_eq!(query.len(), 10); numerator[0] = E::from(query[1]); numerator[1] = E::ONE; numerator[2] = E::ONE; From 09ed09b055f10fab37345ebef8263be2b9eef682 Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Fri, 20 Sep 2024 19:43:12 +0200 Subject: [PATCH 15/19] Parallelize input layer generation (#324) --- prover/src/logup_gkr/mod.rs | 88 +++++++++++++--------- utils/core/src/iterators.rs | 18 +++++ winterfell/src/tests/logup_gkr_periodic.rs | 2 +- 3 files changed, 72 insertions(+), 36 deletions(-) diff --git a/prover/src/logup_gkr/mod.rs b/prover/src/logup_gkr/mod.rs index 474d8fd7e..016cd5218 100644 --- a/prover/src/logup_gkr/mod.rs +++ b/prover/src/logup_gkr/mod.rs @@ -5,12 +5,17 @@ use air::{EvaluationFrame, GkrData, LogUpGkrEvaluator}; use math::FieldElement; use sumcheck::{EqFunction, MultiLinearPoly, SumCheckProverError}; use tracing::instrument; -use utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; +use utils::{ + batch_iter_mut, chunks, uninit_vector, ByteReader, ByteWriter, Deserializable, + DeserializationError, Serializable, +}; use crate::Trace; mod prover; pub use prover::prove_gkr; +#[cfg(feature = "concurrent")] +pub use utils::rayon::{current_num_threads as rayon_num_threads, prelude::*}; // EVALUATED CIRCUIT // ================================================================================================ @@ -106,7 +111,7 @@ impl EvaluatedCircuit { /// Generates the input layer of the circuit from the main trace columns and some randomness /// provided by the verifier. fn generate_input_layer( - main_trace: &impl Trace, + trace: &impl Trace, evaluator: &impl LogUpGkrEvaluator, log_up_randomness: &[E], ) -> CircuitLayer { @@ -114,45 +119,58 @@ impl EvaluatedCircuit { let periodic_values = evaluator.build_periodic_values(); let mut input_layer_wires = - Vec::with_capacity(main_trace.main_segment().num_rows() * num_fractions); - let mut main_frame = EvaluationFrame::new(main_trace.main_segment().num_cols()); - - let mut query = vec![E::BaseField::ZERO; evaluator.get_oracles().len()]; - let mut periodic_values_row = vec![E::BaseField::ZERO; periodic_values.num_columns()]; - let mut numerators = vec![E::ZERO; num_fractions]; - let mut denominators = vec![E::ZERO; num_fractions]; - for i in 0..main_trace.main_segment().num_rows() { - let wires_from_trace_row = { - main_trace.read_main_frame(i, &mut main_frame); - periodic_values.fill_periodic_values_at(i, &mut periodic_values_row); - evaluator.build_query(&main_frame, &mut query); - - evaluator.evaluate_query( - &query, - &periodic_values_row, - log_up_randomness, - &mut numerators, - &mut denominators, - ); - let input_gates_values: Vec> = numerators - .iter() - .zip(denominators.iter()) - .map(|(numerator, denominator)| CircuitWire::new(*numerator, *denominator)) - .collect(); - input_gates_values - }; - - input_layer_wires.extend(wires_from_trace_row); - } + unsafe { uninit_vector(trace.main_segment().num_rows() * num_fractions) }; + let num_cols = trace.main_segment().num_cols(); + let num_oracles = evaluator.get_oracles().len(); + let num_periodic_cols = periodic_values.num_columns(); + + batch_iter_mut!( + &mut input_layer_wires, + 1024, + |batch: &mut [CircuitWire], batch_offset: usize| { + let mut main_frame = EvaluationFrame::new(num_cols); + let mut query = vec![E::BaseField::ZERO; num_oracles]; + let mut periodic_values_row = vec![E::BaseField::ZERO; num_periodic_cols]; + let mut numerators = vec![E::ZERO; num_fractions]; + let mut denominators = vec![E::ZERO; num_fractions]; + + let row_offset = batch_offset / num_fractions; + let batch_size = batch.len(); + let num_rows_per_batch = batch_size / num_fractions; + + for i in + (0..trace.main_segment().num_rows()).skip(row_offset).take(num_rows_per_batch) + { + trace.read_main_frame(i, &mut main_frame); + periodic_values.fill_periodic_values_at(i, &mut periodic_values_row); + evaluator.build_query(&main_frame, &mut query); + + evaluator.evaluate_query( + &query, + &periodic_values_row, + log_up_randomness, + &mut numerators, + &mut denominators, + ); + + let n = (i - row_offset) * num_fractions; + for ((wire, numerator), denominator) in batch[n..n + num_fractions] + .iter_mut() + .zip(numerators.iter()) + .zip(denominators.iter()) + { + *wire = CircuitWire::new(*numerator, *denominator); + } + } + } + ); CircuitLayer::new(input_layer_wires) } /// Computes the subsequent layer of the circuit from a given layer. fn compute_next_layer(prev_layer: &CircuitLayer) -> CircuitLayer { - let next_layer_wires = prev_layer - .wires() - .chunks_exact(2) + let next_layer_wires = chunks!(prev_layer.wires(), 2) .map(|input_wires| { let left_input_wire = input_wires[0]; let right_input_wire = input_wires[1]; diff --git a/utils/core/src/iterators.rs b/utils/core/src/iterators.rs index 2d9782730..f978acd40 100644 --- a/utils/core/src/iterators.rs +++ b/utils/core/src/iterators.rs @@ -115,3 +115,21 @@ macro_rules! batch_iter_mut { $c($e, 0); }; } + +/// Returns either a regular or a parallel iterator over at most `chunk_size` elements depending +/// on whether `concurrent` feature is enabled. +/// +/// When `concurrent` feature is enabled, creates a parallel iterator; otherwise, creates a +/// regular iterator. +#[macro_export] +macro_rules! chunks { + ($e: expr, $chunk_size: expr) => {{ + #[cfg(feature = "concurrent")] + let result = $e.par_chunks($chunk_size); + + #[cfg(not(feature = "concurrent"))] + let result = $e.chunks($chunk_size); + + result + }}; +} diff --git a/winterfell/src/tests/logup_gkr_periodic.rs b/winterfell/src/tests/logup_gkr_periodic.rs index 5a5b369c1..849cbbd5d 100644 --- a/winterfell/src/tests/logup_gkr_periodic.rs +++ b/winterfell/src/tests/logup_gkr_periodic.rs @@ -23,7 +23,7 @@ use crate::{ #[test] fn test_logup_gkr_periodic() { let aux_trace_width = 1; - let trace = LogUpGkrPeriodic::new(2_usize.pow(7), aux_trace_width); + let trace = LogUpGkrPeriodic::new(2_usize.pow(12), aux_trace_width); let prover = LogUpGkrPeriodicProver::new(aux_trace_width); let proof = prover.prove(trace).unwrap(); From a472fa2357d4b2f4f187cee4489d589d4969ee76 Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Thu, 26 Sep 2024 20:47:01 +0200 Subject: [PATCH 16/19] Fix bug parallel execution of input layer proving (#331) --- prover/Cargo.toml | 6 +- prover/benches/logup_gkr.rs | 121 ++------- prover/benches/logup_gkr_e2e.rs | 368 ++++++++++++++++++++++++++++ prover/src/lib.rs | 2 +- sumcheck/benches/sum_check_plain.rs | 2 - sumcheck/src/prover/high_degree.rs | 12 +- 6 files changed, 400 insertions(+), 111 deletions(-) create mode 100644 prover/benches/logup_gkr_e2e.rs diff --git a/prover/Cargo.toml b/prover/Cargo.toml index 37e45c472..f75acdc8e 100644 --- a/prover/Cargo.toml +++ b/prover/Cargo.toml @@ -19,6 +19,10 @@ bench = false name = "logup_gkr" harness = false +[[bench]] +name = "logup_gkr_e2e" +harness = false + [[bench]] name = "row_matrix" harness = false @@ -29,7 +33,7 @@ harness = false [features] async = ["maybe_async/async"] -concurrent = ["crypto/concurrent", "math/concurrent", "fri/concurrent", "utils/concurrent", "std"] +concurrent = ["crypto/concurrent", "math/concurrent", "fri/concurrent", "utils/concurrent", "sumcheck/concurrent", "std"] default = ["std"] std = ["air/std", "crypto/std", "fri/std", "math/std", "utils/std"] diff --git a/prover/benches/logup_gkr.rs b/prover/benches/logup_gkr.rs index 6e67eddc2..e86c84aef 100644 --- a/prover/benches/logup_gkr.rs +++ b/prover/benches/logup_gkr.rs @@ -6,24 +6,22 @@ use std::{marker::PhantomData, time::Duration, vec::Vec}; use air::{ - Air, AirContext, Assertion, AuxRandElements, ConstraintCompositionCoefficients, - EvaluationFrame, FieldExtension, LogUpGkrEvaluator, LogUpGkrOracle, ProofOptions, TraceInfo, - TransitionConstraintDegree, + Air, AirContext, Assertion, AuxRandElements, EvaluationFrame, LogUpGkrEvaluator, + LogUpGkrOracle, ProofOptions, TraceInfo, TransitionConstraintDegree, }; use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion}; -use crypto::MerkleTree; +use crypto::RandomCoin; use math::StarkField; use winter_prover::{ crypto::{hashers::Blake3_256, DefaultRandomCoin}, math::{fields::f64::BaseElement, ExtensionOf, FieldElement}, matrix::ColMatrix, - DefaultTraceLde, LogUpGkrConstraintEvaluator, Prover, StarkDomain, Trace, TracePolyTable, + prove_gkr, Trace, }; -const TRACE_LENS: [usize; 2] = [2_usize.pow(18), 2_usize.pow(20)]; -const AUX_TRACE_WIDTH: usize = 2; +const TRACE_LENS: [usize; 4] = [2_usize.pow(18), 2_usize.pow(19), 2_usize.pow(20), 2_usize.pow(21)]; -/// Simple end-to-end benchmark for LogUp-GKR. +/// Simple benchmark for the GKR part of STARK with LogUp-GKR. /// /// The main trace contains `5` columns and the LogUp relation is a simple one where we have: /// @@ -33,27 +31,31 @@ const AUX_TRACE_WIDTH: usize = 2; /// /// Given the above, the benchmark then gives an idea about the minimal overhead due to enabling /// LogUp-GKR. The overhead could be bigger depending on the complexity of the LogUp relation. -fn prove_with_lagrange_kernel(c: &mut Criterion) { - let mut group = c.benchmark_group("prove with Lagrange kernel column"); +fn prove_with_logup_gkr(c: &mut Criterion) { + let mut group = c.benchmark_group("prove LogUp-GKR"); group.sample_size(10); group.measurement_time(Duration::from_secs(20)); for &trace_len in TRACE_LENS.iter() { group.bench_function(BenchmarkId::new("", trace_len), |b| { - let trace = LogUpGkrSimpleTrace::new(trace_len, AUX_TRACE_WIDTH); - let prover = LogUpGkrSimpleProver::new(AUX_TRACE_WIDTH); + let main_trace = LogUpGkrSimpleTrace::new(trace_len); + let evaluator = PlainLogUpGkrEval::new(); b.iter_batched( - || trace.clone(), - |trace| prover.prove(trace).unwrap(), + || (main_trace.clone(), evaluator.clone()), + |(main_trace, evaluator)| { + let mut public_coin = + DefaultRandomCoin::>::new(&[BaseElement::ZERO; 4]); + prove_gkr::(&main_trace, &evaluator, &mut public_coin) + }, BatchSize::SmallInput, ) }); } } -criterion_group!(lagrange_kernel_group, prove_with_lagrange_kernel); -criterion_main!(lagrange_kernel_group); +criterion_group!(logup_gkr_group, prove_with_logup_gkr); +criterion_main!(logup_gkr_group); // LogUpGkrSimple // ================================================================================================= @@ -66,7 +68,7 @@ struct LogUpGkrSimpleTrace { } impl LogUpGkrSimpleTrace { - fn new(trace_len: usize, aux_segment_width: usize) -> Self { + fn new(trace_len: usize) -> Self { assert!(trace_len < u32::MAX.try_into().unwrap()); // we create a column for the table we are looking values into. These are just the integers @@ -103,7 +105,7 @@ impl LogUpGkrSimpleTrace { Self { main_trace: ColMatrix::new(vec![table, multiplicity, values_0, values_1, values_2]), - info: TraceInfo::new_multi_segment(5, aux_segment_width, 0, trace_len, vec![], true), + info: TraceInfo::new_multi_segment(5, 0, 0, trace_len, vec![], true), } } @@ -283,86 +285,3 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { E::ZERO } } -// Prover -// ================================================================================================ - -struct LogUpGkrSimpleProver { - aux_trace_width: usize, - options: ProofOptions, -} - -impl LogUpGkrSimpleProver { - fn new(aux_trace_width: usize) -> Self { - Self { - aux_trace_width, - options: ProofOptions::new(1, 8, 0, FieldExtension::Quadratic, 2, 1), - } - } -} - -impl Prover for LogUpGkrSimpleProver { - type BaseField = BaseElement; - type Air = LogUpGkrSimpleAir; - type Trace = LogUpGkrSimpleTrace; - type HashFn = Blake3_256; - type VC = MerkleTree>; - type RandomCoin = DefaultRandomCoin; - type TraceLde> = - DefaultTraceLde; - type ConstraintEvaluator<'a, E: FieldElement> = - LogUpGkrConstraintEvaluator<'a, LogUpGkrSimpleAir, E>; - - fn get_pub_inputs(&self, _trace: &Self::Trace) -> <::Air as Air>::PublicInputs { - } - - fn options(&self) -> &ProofOptions { - &self.options - } - - fn new_trace_lde( - &self, - trace_info: &TraceInfo, - main_trace: &ColMatrix, - domain: &StarkDomain, - ) -> (Self::TraceLde, TracePolyTable) - where - E: math::FieldElement, - { - DefaultTraceLde::new(trace_info, main_trace, domain) - } - - fn new_evaluator<'a, E>( - &self, - air: &'a Self::Air, - aux_rand_elements: Option>, - composition_coefficients: ConstraintCompositionCoefficients, - ) -> Self::ConstraintEvaluator<'a, E> - where - E: math::FieldElement, - { - LogUpGkrConstraintEvaluator::new(air, aux_rand_elements.unwrap(), composition_coefficients) - } - - fn build_aux_trace(&self, main_trace: &Self::Trace, _aux_rand_elements: &[E]) -> ColMatrix - where - E: FieldElement, - { - let main_trace = main_trace.main_segment(); - - let mut columns = Vec::new(); - - let rand_summed = E::from(777_u32); - for _ in 0..self.aux_trace_width { - // building a dummy auxiliary column - let column = main_trace - .get_column(0) - .iter() - .map(|row_val| rand_summed.mul_base(*row_val)) - .collect(); - - columns.push(column); - } - - ColMatrix::new(columns) - } -} diff --git a/prover/benches/logup_gkr_e2e.rs b/prover/benches/logup_gkr_e2e.rs new file mode 100644 index 000000000..2f81dd850 --- /dev/null +++ b/prover/benches/logup_gkr_e2e.rs @@ -0,0 +1,368 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use std::{marker::PhantomData, time::Duration, vec::Vec}; + +use air::{ + Air, AirContext, Assertion, AuxRandElements, ConstraintCompositionCoefficients, + EvaluationFrame, FieldExtension, LogUpGkrEvaluator, LogUpGkrOracle, ProofOptions, TraceInfo, + TransitionConstraintDegree, +}; +use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion}; +use crypto::MerkleTree; +use math::StarkField; +use winter_prover::{ + crypto::{hashers::Blake3_256, DefaultRandomCoin}, + math::{fields::f64::BaseElement, ExtensionOf, FieldElement}, + matrix::ColMatrix, + DefaultTraceLde, LogUpGkrConstraintEvaluator, Prover, StarkDomain, Trace, TracePolyTable, +}; + +const TRACE_LENS: [usize; 2] = [2_usize.pow(18), 2_usize.pow(20)]; +const AUX_TRACE_WIDTH: usize = 2; + +/// Simple end-to-end benchmark for LogUp-GKR. +/// +/// The main trace contains `5` columns and the LogUp relation is a simple one where we have: +/// +/// 1. a table of values from `0` to `trace_len - 1`. +/// 2. a multiplicity column containing the number of look ups for each value in the table. +/// 3. three columns with values contained in the table above. +/// +/// Given the above, the benchmark then gives an idea about the minimal overhead due to enabling +/// LogUp-GKR. The overhead could be bigger depending on the complexity of the LogUp relation. +fn prove_with_logup_gkr(c: &mut Criterion) { + let mut group = c.benchmark_group("prove with LogUp-GKR"); + group.sample_size(10); + group.measurement_time(Duration::from_secs(20)); + + for &trace_len in TRACE_LENS.iter() { + group.bench_function(BenchmarkId::new("", trace_len), |b| { + let trace = LogUpGkrSimpleTrace::new(trace_len, AUX_TRACE_WIDTH); + let prover = LogUpGkrSimpleProver::new(AUX_TRACE_WIDTH); + + b.iter_batched( + || trace.clone(), + |trace| prover.prove(trace).unwrap(), + BatchSize::SmallInput, + ) + }); + } +} + +criterion_group!(logup_gkr_group, prove_with_logup_gkr); +criterion_main!(logup_gkr_group); + +// LogUpGkrSimple +// ================================================================================================= + +#[derive(Clone, Debug)] +struct LogUpGkrSimpleTrace { + // dummy main trace + main_trace: ColMatrix, + info: TraceInfo, +} + +impl LogUpGkrSimpleTrace { + fn new(trace_len: usize, aux_segment_width: usize) -> Self { + assert!(trace_len < u32::MAX.try_into().unwrap()); + + // we create a column for the table we are looking values into. These are just the integers + // from 0 to `trace_len`. + let table: Vec = + (0..trace_len).map(|idx| BaseElement::from(idx as u32)).collect(); + + // we create three columns that contains values contained in `table`. For simplicity, we + // look up only the values `0` or `1`, we look up the value `1` four times and the value `0` + // `trace_len - 4` times. + + let mut values_0: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + for i in 0..4 { + values_0[i + 4] = BaseElement::ONE; + } + + let mut values_1: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + for i in 0..4 { + values_1[i + 4] = BaseElement::ONE; + } + + let mut values_2: Vec = (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + for i in 0..4 { + values_2[i + 4] = BaseElement::ONE; + } + + // we create the multiplicity column + let mut multiplicity: Vec = + (0..trace_len).map(|_idx| BaseElement::ZERO).collect(); + // we look up the value `1` four times in three columns + multiplicity[1] = BaseElement::new(3 * 4); + // we look up the value `0` `trace_len - 4` in three columns + multiplicity[0] = BaseElement::new(3 * trace_len as u64 - 3 * 4); + + Self { + main_trace: ColMatrix::new(vec![table, multiplicity, values_0, values_1, values_2]), + info: TraceInfo::new_multi_segment(5, aux_segment_width, 0, trace_len, vec![], true), + } + } + + fn len(&self) -> usize { + self.main_trace.num_rows() + } +} + +impl Trace for LogUpGkrSimpleTrace { + type BaseField = BaseElement; + + fn info(&self) -> &TraceInfo { + &self.info + } + + fn main_segment(&self) -> &ColMatrix { + &self.main_trace + } + + fn read_main_frame(&self, row_idx: usize, frame: &mut EvaluationFrame) { + let next_row_idx = row_idx + 1; + self.main_trace.read_row_into(row_idx, frame.current_mut()); + self.main_trace.read_row_into(next_row_idx % self.len(), frame.next_mut()); + } +} + +// AIR +// ================================================================================================= + +struct LogUpGkrSimpleAir { + context: AirContext, +} + +impl Air for LogUpGkrSimpleAir { + type BaseField = BaseElement; + type PublicInputs = (); + + fn new(trace_info: TraceInfo, _pub_inputs: Self::PublicInputs, options: ProofOptions) -> Self { + Self { + context: AirContext::new_multi_segment( + trace_info, + _pub_inputs, + vec![TransitionConstraintDegree::new(1)], + vec![], + 1, + 0, + options, + ), + } + } + + fn context(&self) -> &AirContext { + &self.context + } + + fn evaluate_transition>( + &self, + frame: &EvaluationFrame, + _periodic_values: &[E], + result: &mut [E], + ) { + let current = frame.current()[0]; + let next = frame.next()[0]; + + // increments by 1 + result[0] = next - current - E::ONE; + } + + fn get_assertions(&self) -> Vec> { + vec![Assertion::single(0, 0, BaseElement::ZERO)] + } + + fn evaluate_aux_transition( + &self, + _main_frame: &EvaluationFrame, + _aux_frame: &EvaluationFrame, + _periodic_values: &[F], + _aux_rand_elements: &AuxRandElements, + _result: &mut [E], + ) where + F: FieldElement, + E: FieldElement + ExtensionOf, + { + // do nothing + } + + fn get_aux_assertions>( + &self, + _aux_rand_elements: &AuxRandElements, + ) -> Vec> { + vec![] + } + + fn get_logup_gkr_evaluator( + &self, + ) -> impl LogUpGkrEvaluator + { + PlainLogUpGkrEval::new() + } +} + +#[derive(Clone, Default)] +pub struct PlainLogUpGkrEval { + oracles: Vec, + _field: PhantomData, +} + +impl PlainLogUpGkrEval { + pub fn new() -> Self { + let committed_0 = LogUpGkrOracle::CurrentRow(0); + let committed_1 = LogUpGkrOracle::CurrentRow(1); + let committed_2 = LogUpGkrOracle::CurrentRow(2); + let committed_3 = LogUpGkrOracle::CurrentRow(3); + let committed_4 = LogUpGkrOracle::CurrentRow(4); + let oracles = vec![committed_0, committed_1, committed_2, committed_3, committed_4]; + Self { oracles, _field: PhantomData } + } +} + +impl LogUpGkrEvaluator for PlainLogUpGkrEval { + type BaseField = BaseElement; + + type PublicInputs = (); + + fn get_oracles(&self) -> &[LogUpGkrOracle] { + &self.oracles + } + + fn get_num_rand_values(&self) -> usize { + 1 + } + + fn get_num_fractions(&self) -> usize { + 4 + } + + fn max_degree(&self) -> usize { + 3 + } + + fn build_query(&self, frame: &EvaluationFrame, query: &mut [E]) + where + E: FieldElement, + { + query.iter_mut().zip(frame.current().iter()).for_each(|(q, f)| *q = *f) + } + + fn evaluate_query( + &self, + query: &[F], + _periodic_values: &[F], + rand_values: &[E], + numerator: &mut [E], + denominator: &mut [E], + ) where + F: FieldElement, + E: FieldElement + ExtensionOf, + { + assert_eq!(numerator.len(), 4); + assert_eq!(denominator.len(), 4); + assert_eq!(query.len(), 5); + numerator[0] = E::from(query[1]); + numerator[1] = E::ONE; + numerator[2] = E::ONE; + numerator[3] = E::ONE; + + denominator[0] = rand_values[0] - E::from(query[0]); + denominator[1] = -(rand_values[0] - E::from(query[2])); + denominator[2] = -(rand_values[0] - E::from(query[3])); + denominator[3] = -(rand_values[0] - E::from(query[4])); + } + + fn compute_claim(&self, _inputs: &Self::PublicInputs, _rand_values: &[E]) -> E + where + E: FieldElement, + { + E::ZERO + } +} +// Prover +// ================================================================================================ + +struct LogUpGkrSimpleProver { + aux_trace_width: usize, + options: ProofOptions, +} + +impl LogUpGkrSimpleProver { + fn new(aux_trace_width: usize) -> Self { + Self { + aux_trace_width, + options: ProofOptions::new(1, 8, 0, FieldExtension::Quadratic, 2, 1), + } + } +} + +impl Prover for LogUpGkrSimpleProver { + type BaseField = BaseElement; + type Air = LogUpGkrSimpleAir; + type Trace = LogUpGkrSimpleTrace; + type HashFn = Blake3_256; + type VC = MerkleTree>; + type RandomCoin = DefaultRandomCoin; + type TraceLde> = + DefaultTraceLde; + type ConstraintEvaluator<'a, E: FieldElement> = + LogUpGkrConstraintEvaluator<'a, LogUpGkrSimpleAir, E>; + + fn get_pub_inputs(&self, _trace: &Self::Trace) -> <::Air as Air>::PublicInputs { + } + + fn options(&self) -> &ProofOptions { + &self.options + } + + fn new_trace_lde( + &self, + trace_info: &TraceInfo, + main_trace: &ColMatrix, + domain: &StarkDomain, + ) -> (Self::TraceLde, TracePolyTable) + where + E: math::FieldElement, + { + DefaultTraceLde::new(trace_info, main_trace, domain) + } + + fn new_evaluator<'a, E>( + &self, + air: &'a Self::Air, + aux_rand_elements: Option>, + composition_coefficients: ConstraintCompositionCoefficients, + ) -> Self::ConstraintEvaluator<'a, E> + where + E: math::FieldElement, + { + LogUpGkrConstraintEvaluator::new(air, aux_rand_elements.unwrap(), composition_coefficients) + } + + fn build_aux_trace(&self, main_trace: &Self::Trace, _aux_rand_elements: &[E]) -> ColMatrix + where + E: FieldElement, + { + let main_trace = main_trace.main_segment(); + + let mut columns = Vec::new(); + + let rand_summed = E::from(777_u32); + for _ in 0..self.aux_trace_width { + // building a dummy auxiliary column + let column = main_trace + .get_column(0) + .iter() + .map(|row_val| rand_summed.mul_base(*row_val)) + .collect(); + + columns.push(column); + } + + ColMatrix::new(columns) + } +} diff --git a/prover/src/lib.rs b/prover/src/lib.rs index 9bcc566b6..d9da99dd5 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -88,7 +88,7 @@ pub use trace::{ }; mod logup_gkr; -use logup_gkr::{build_lagrange_column, build_s_column, prove_gkr}; +pub use logup_gkr::{build_lagrange_column, build_s_column, prove_gkr}; mod channel; use channel::ProverChannel; diff --git a/sumcheck/benches/sum_check_plain.rs b/sumcheck/benches/sum_check_plain.rs index 14fd859ce..6e15603ef 100644 --- a/sumcheck/benches/sum_check_plain.rs +++ b/sumcheck/benches/sum_check_plain.rs @@ -12,7 +12,6 @@ use rand_utils::{rand_value, rand_vector}; #[cfg(feature = "concurrent")] pub use rayon::prelude::*; use winter_sumcheck::{sumcheck_prove_plain, EqFunction, MultiLinearPoly}; - const LOG_POLY_SIZE: [usize; 2] = [18, 20]; fn sum_check_plain(c: &mut Criterion) { @@ -31,7 +30,6 @@ fn sum_check_plain(c: &mut Criterion) { |((claim, r_batch, p, q, eq), transcript)| { let mut eq = eq; let mut transcript = transcript; - sumcheck_prove_plain(claim, r_batch, p, q, &mut eq, &mut transcript) }, BatchSize::SmallInput, diff --git a/sumcheck/src/prover/high_degree.rs b/sumcheck/src/prover/high_degree.rs index 47be290d7..0c5c0aff7 100644 --- a/sumcheck/src/prover/high_degree.rs +++ b/sumcheck/src/prover/high_degree.rs @@ -417,11 +417,11 @@ fn sumcheck_round( vec![E::ZERO; num_periodic], vec![E::ZERO; num_periodic], vec![E::ZERO; num_periodic], - vec![E::ZERO; evaluator.max_degree()], vec![E::ZERO; evaluator.get_num_fractions()], vec![E::ZERO; evaluator.get_num_fractions()], vec![E::ZERO; num_mls], vec![E::ZERO; num_periodic], + vec![E::ZERO; evaluator.max_degree()], ) }, |( @@ -431,11 +431,11 @@ fn sumcheck_round( mut evals_periodic_zero, mut evals_periodic_one, mut evals_periodic_x, - mut poly_evals, mut numerators, mut denominators, mut deltas, mut deltas_periodic, + mut poly_evals, ), i| { for (j, ml) in mls.iter().enumerate() { @@ -458,7 +458,7 @@ fn sumcheck_round( &mut numerators, &mut denominators, ); - poly_evals[0] = evaluate_composition_poly( + poly_evals[0] += evaluate_composition_poly( eq_mu, &numerators, &denominators, @@ -496,7 +496,7 @@ fn sumcheck_round( &mut numerators, &mut denominators, ); - *e = evaluate_composition_poly( + *e += evaluate_composition_poly( eq_mu, &numerators, &denominators, @@ -512,15 +512,15 @@ fn sumcheck_round( evals_periodic_zero, evals_periodic_one, evals_periodic_x, - poly_evals, numerators, denominators, deltas, deltas_periodic, + poly_evals, ) }, ) - .map(|(_, _, _, poly_evals, ..)| poly_evals) + .map(|(.., poly_evals)| poly_evals) .reduce( || vec![E::ZERO; evaluator.max_degree()], |mut acc, poly_eval| { From a4e383e74cf68f9792d5f704556c218b103e4652 Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Fri, 27 Sep 2024 09:56:56 +0200 Subject: [PATCH 17/19] Parallelize s-column generation (#326) --- air/src/air/context.rs | 4 +- crypto/src/hash/mds/mds_f64_12x12.rs | 27 +++--- crypto/src/hash/mds/mds_f64_8x8.rs | 27 +++--- prover/src/logup_gkr/mod.rs | 119 +++++++++++++++++++++++---- prover/src/logup_gkr/prover.rs | 5 +- utils/core/src/iterators.rs | 18 ++++ 6 files changed, 154 insertions(+), 46 deletions(-) diff --git a/air/src/air/context.rs b/air/src/air/context.rs index 6eb035af9..a4074036a 100644 --- a/air/src/air/context.rs +++ b/air/src/air/context.rs @@ -305,10 +305,8 @@ impl AirContext { let trace_length = self.trace_len(); let transition_divisior_degree = trace_length - self.num_transition_exemptions(); - // we use the identity: ceil(a/b) = (a + b - 1)/b let num_constraint_col = - (highest_constraint_degree - transition_divisior_degree + trace_length - 1) - / trace_length; + (highest_constraint_degree - transition_divisior_degree).div_ceil(trace_length); cmp::max(num_constraint_col, 1) } diff --git a/crypto/src/hash/mds/mds_f64_12x12.rs b/crypto/src/hash/mds/mds_f64_12x12.rs index 44f5660b9..ddf79f4a2 100644 --- a/crypto/src/hash/mds/mds_f64_12x12.rs +++ b/crypto/src/hash/mds/mds_f64_12x12.rs @@ -12,19 +12,19 @@ use math::{ FieldElement, }; -/// This module contains helper functions as well as constants used to perform a 12x12 vector-matrix -/// multiplication. The special form of our MDS matrix i.e. being circulant, allows us to reduce -/// the vector-matrix multiplication to a Hadamard product of two vectors in "frequency domain". -/// This follows from the simple fact that every circulant matrix has the columns of the discrete -/// Fourier transform matrix as orthogonal eigenvectors. -/// The implementation also avoids the use of 3-point FFTs, and 3-point iFFTs, and substitutes that -/// with explicit expressions. It also avoids, due to the form of our matrix in the frequency domain, -/// divisions by 2 and repeated modular reductions. This is because of our explicit choice of -/// an MDS matrix that has small powers of 2 entries in frequency domain. -/// The following implementation has benefited greatly from the discussions and insights of -/// Hamish Ivey-Law and Jacqueline Nabaglo of Polygon Zero and is based on Nabaglo's implementation -/// in [Plonky2](https://github.com/mir-protocol/plonky2). -/// The circulant matrix is identified by its first row: [7, 23, 8, 26, 13, 10, 9, 7, 6, 22, 21, 8]. +// This module contains helper functions as well as constants used to perform a 12x12 vector-matrix +// multiplication. The special form of our MDS matrix i.e. being circulant, allows us to reduce +// the vector-matrix multiplication to a Hadamard product of two vectors in "frequency domain". +// This follows from the simple fact that every circulant matrix has the columns of the discrete +// Fourier transform matrix as orthogonal eigenvectors. +// The implementation also avoids the use of 3-point FFTs, and 3-point iFFTs, and substitutes that +// with explicit expressions. It also avoids, due to the form of our matrix in the frequency domain, +// divisions by 2 and repeated modular reductions. This is because of our explicit choice of +// an MDS matrix that has small powers of 2 entries in frequency domain. +// The following implementation has benefited greatly from the discussions and insights of +// Hamish Ivey-Law and Jacqueline Nabaglo of Polygon Zero and is based on Nabaglo's implementation +// in [Plonky2](https://github.com/mir-protocol/plonky2). +// The circulant matrix is identified by its first row: [7, 23, 8, 26, 13, 10, 9, 7, 6, 22, 21, 8]. // MDS matrix in frequency domain. // More precisely, this is the output of the three 4-point (real) FFTs of the first column of @@ -33,6 +33,7 @@ use math::{ // The entries have been scaled appropriately in order to avoid divisions by 2 in iFFT2 and iFFT4. // The code to generate the matrix in frequency domain is based on an adaptation of a code, to generate // MDS matrices efficiently in original domain, that was developed by the Polygon Zero team. + const MDS_FREQ_BLOCK_ONE: [i64; 3] = [16, 8, 16]; const MDS_FREQ_BLOCK_TWO: [(i64, i64); 3] = [(-1, 2), (-1, 1), (4, 8)]; const MDS_FREQ_BLOCK_THREE: [i64; 3] = [-8, 1, 1]; diff --git a/crypto/src/hash/mds/mds_f64_8x8.rs b/crypto/src/hash/mds/mds_f64_8x8.rs index 037dee721..4e7818357 100644 --- a/crypto/src/hash/mds/mds_f64_8x8.rs +++ b/crypto/src/hash/mds/mds_f64_8x8.rs @@ -12,25 +12,26 @@ use math::{ FieldElement, }; -/// This module contains helper functions as well as constants used to perform a 8x8 vector-matrix -/// multiplication. The special form of our MDS matrix i.e. being circulant, allows us to reduce -/// the vector-matrix multiplication to a Hadamard product of two vectors in "frequency domain". -/// This follows from the simple fact that every circulant matrix has the columns of the discrete -/// Fourier transform matrix as orthogonal eigenvectors. -/// The implementation also avoids the use of internal 2-point FFTs, and 2-point iFFTs, and substitutes -/// them with explicit expressions. It also avoids, due to the form of our matrix in the frequency domain, -/// divisions by 2 and repeated modular reductions. This is because of our explicit choice of -/// an MDS matrix that has small powers of 2 entries in frequency domain. -/// The following implementation has benefited greatly from the discussions and insights of -/// Hamish Ivey-Law and Jacqueline Nabaglo of Polygon Zero is based on Nabaglo's implementation -/// in [Plonky2](https://github.com/mir-protocol/plonky2). -/// The circulant matrix is identified by its first row: [23, 8, 13, 10, 7, 6, 21, 8]. +// This module contains helper functions as well as constants used to perform a 8x8 vector-matrix +// multiplication. The special form of our MDS matrix i.e. being circulant, allows us to reduce +// the vector-matrix multiplication to a Hadamard product of two vectors in "frequency domain". +// This follows from the simple fact that every circulant matrix has the columns of the discrete +// Fourier transform matrix as orthogonal eigenvectors. +// The implementation also avoids the use of internal 2-point FFTs, and 2-point iFFTs, and substitutes +// them with explicit expressions. It also avoids, due to the form of our matrix in the frequency domain, +// divisions by 2 and repeated modular reductions. This is because of our explicit choice of +// an MDS matrix that has small powers of 2 entries in frequency domain. +// The following implementation has benefited greatly from the discussions and insights of +// Hamish Ivey-Law and Jacqueline Nabaglo of Polygon Zero is based on Nabaglo's implementation +// in [Plonky2](https://github.com/mir-protocol/plonky2). +// The circulant matrix is identified by its first row: [23, 8, 13, 10, 7, 6, 21, 8]. // MDS matrix in frequency domain. // More precisely, this is the output of the two 4-point (real) FFTs of the first column of // the MDS matrix i.e. just before the multiplication with the appropriate twiddle factors // and application of the final four 2-point FFT in order to get the full 8-point FFT. // The entries have been scaled appropriately in order to avoid divisions by 2 in iFFT2 and iFFT4. + const MDS_FREQ_BLOCK_ONE: [i64; 2] = [16, 8]; const MDS_FREQ_BLOCK_TWO: [(i64, i64); 2] = [(8, -4), (-1, 1)]; const MDS_FREQ_BLOCK_THREE: [i64; 2] = [-1, 1]; diff --git a/prover/src/logup_gkr/mod.rs b/prover/src/logup_gkr/mod.rs index 016cd5218..4d2ee975a 100644 --- a/prover/src/logup_gkr/mod.rs +++ b/prover/src/logup_gkr/mod.rs @@ -15,7 +15,10 @@ use crate::Trace; mod prover; pub use prover::prove_gkr; #[cfg(feature = "concurrent")] -pub use utils::rayon::{current_num_threads as rayon_num_threads, prelude::*}; +pub use utils::{ + rayon::{current_num_threads as rayon_num_threads, prelude::*}, + {chunks_mut, iter, iter_mut}, +}; // EVALUATED CIRCUIT // ================================================================================================ @@ -384,31 +387,65 @@ where /// /// [1]: https://eprint.iacr.org/2023/1284 pub fn build_s_column( - main_trace: &impl Trace, + trace: &impl Trace, gkr_data: &GkrData, evaluator: &impl LogUpGkrEvaluator, lagrange_kernel_col: &[E], ) -> Vec { let c = gkr_data.compute_batched_claim(); - let main_segment = main_trace.main_segment(); - let mean = c / E::from(E::BaseField::from(main_segment.num_rows() as u32)); + let num_oracles = evaluator.get_oracles().len(); - let mut result = Vec::with_capacity(main_segment.num_rows()); - let mut last_value = E::ZERO; - result.push(last_value); + let main_segment = trace.main_segment(); + let num_cols = main_segment.num_cols(); + let num_rows = main_segment.num_rows(); + let mean = c / E::from(E::BaseField::from(num_rows as u32)); - let mut query = vec![E::BaseField::ZERO; evaluator.get_oracles().len()]; - let mut main_frame = EvaluationFrame::new(main_trace.main_segment().num_cols()); + #[cfg(not(feature = "concurrent"))] + let result = { + let mut result = Vec::with_capacity(num_rows); + let mut last_value = E::ZERO; + result.push(last_value); - for (i, item) in lagrange_kernel_col.iter().enumerate().take(main_segment.num_rows() - 1) { - main_trace.read_main_frame(i, &mut main_frame); + let mut query = vec![E::BaseField::ZERO; num_oracles]; + let mut main_frame = EvaluationFrame::new(num_cols); - evaluator.build_query(&main_frame, &mut query); - let cur_value = last_value - mean + gkr_data.compute_batched_query(&query) * *item; + for (i, item) in lagrange_kernel_col.iter().enumerate().take(num_rows - 1) { + trace.read_main_frame(i, &mut main_frame); - result.push(cur_value); - last_value = cur_value; - } + evaluator.build_query(&main_frame, &mut query); + let cur_value = last_value - mean + gkr_data.compute_batched_query(&query) * *item; + + result.push(cur_value); + last_value = cur_value; + } + + result + }; + + #[cfg(feature = "concurrent")] + let result = { + let mut deltas = unsafe { uninit_vector(num_rows) }; + deltas[0] = E::ZERO; + let batch_size = num_rows / rayon_num_threads().next_power_of_two(); + batch_iter_mut!(&mut deltas[1..], batch_size, |batch: &mut [E], batch_offset: usize| { + let mut query = vec![E::BaseField::ZERO; num_oracles]; + let mut main_frame = EvaluationFrame::::new(num_cols); + + for (i, v) in batch.iter_mut().enumerate() { + trace.read_main_frame(i + batch_offset, &mut main_frame); + + evaluator.build_query(&main_frame, &mut query); + *v = gkr_data.compute_batched_query(&query) * lagrange_kernel_col[i + batch_offset] + - mean; + } + }); + + // note that `deltas[0]` is set `0` and thus `deltas` satisfies the conditions for invoking + // the function + let mut cumulative_sum = deltas; + prefix_sum_parallel(&mut cumulative_sum, batch_size); + cumulative_sum + }; result } @@ -425,3 +462,53 @@ pub enum GkrProverError { #[error("failed to generate the random challenge")] FailedToGenerateChallenge, } + +// HELPER +// ================================================================================================= + +/// Computes the cumulative sum, also called prefix sum, of a vector of field elements using +/// parallelism, in place. +/// +/// The function divides the vector into non-overlapping segments and then computes an array of sums +/// for each segment. The function then applies the naive serial implementation to each segment and +/// uses the pre-computed sums in each segment in order to coordinate the results in the different +/// segments. +/// +/// The input vector is of the form `0 || values` where `values` are the values the cumulative sum +/// vector will be computed for, in place. +#[cfg(feature = "concurrent")] +fn prefix_sum_parallel(vector: &mut [E], batch_size: usize) { + let num_partitions = vector.len().div_ceil(batch_size); + let mut sum_per_partition = vec![E::ZERO; num_partitions]; + + chunks!(vector, batch_size) + .zip(iter_mut!(sum_per_partition)) + .for_each(|(chunk, entry)| *entry = chunk.iter().fold(E::ZERO, |acc, term| acc + *term)); + + prefix_sum_truncate_right(&mut sum_per_partition); + + chunks_mut!(vector, batch_size) + .zip(iter!(sum_per_partition)) + .for_each(|(chunk, sum_so_far)| prefix_sum_truncate_left(chunk, *sum_so_far)); +} + +/// Computes the cumulative sum of a vector but omits the final cumulative sum. +#[cfg(feature = "concurrent")] +fn prefix_sum_truncate_right(values: &mut [E]) { + let mut sum = E::ZERO; + values.iter_mut().for_each(|v| { + let tmp = *v; + *v = sum; + sum += tmp; + }); +} + +/// Computes the cumulative sum of a vector but omits the initial cumulative sum, namely zero. +#[cfg(feature = "concurrent")] +fn prefix_sum_truncate_left(values: &mut [E], sum: E) { + let mut sum = sum; + values.iter_mut().for_each(|v| { + sum += *v; + *v = sum; + }); +} diff --git a/prover/src/logup_gkr/prover.rs b/prover/src/logup_gkr/prover.rs index f258d0845..899482f22 100644 --- a/prover/src/logup_gkr/prover.rs +++ b/prover/src/logup_gkr/prover.rs @@ -8,6 +8,9 @@ use sumcheck::{ EqFunction, FinalLayerProof, GkrCircuitProof, MultiLinearPoly, SumCheckProof, }; use tracing::instrument; +use utils::iter; +#[cfg(feature = "concurrent")] +pub use utils::rayon::prelude::*; use super::{reduce_layer_claim, CircuitLayerPolys, EvaluatedCircuit, GkrClaim, GkrProverError}; use crate::{matrix::ColMatrix, Trace}; @@ -147,7 +150,7 @@ fn build_mls_from_main_trace_segment( match oracle { LogUpGkrOracle::CurrentRow(index) => { let col = main_trace.get_column(*index); - let values: Vec = col.iter().map(|value| E::from(*value)).collect(); + let values: Vec = iter!(col).map(|value| E::from(*value)).collect(); let ml = MultiLinearPoly::from_evaluations(values); mls.push(ml) }, diff --git a/utils/core/src/iterators.rs b/utils/core/src/iterators.rs index f978acd40..cf8328ae9 100644 --- a/utils/core/src/iterators.rs +++ b/utils/core/src/iterators.rs @@ -133,3 +133,21 @@ macro_rules! chunks { result }}; } + +/// Returns either a regular or a parallel mutable iterator over at most `chunk_size` elements +/// depending on whether `concurrent` feature is enabled. +/// +/// When `concurrent` feature is enabled, creates a parallel iterator; otherwise, creates a +/// regular iterator. +#[macro_export] +macro_rules! chunks_mut { + ($e: expr, $chunk_size: expr) => {{ + #[cfg(feature = "concurrent")] + let result = $e.par_chunks_mut($chunk_size); + + #[cfg(not(feature = "concurrent"))] + let result = $e.chunks_mut($chunk_size); + + result + }}; +} From 536fa132d839a3b221b727d6dd823eaea5c86977 Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Fri, 27 Sep 2024 18:48:33 +0200 Subject: [PATCH 18/19] Improve construction of MLEs from main trace segment (#329) --- prover/Cargo.toml | 8 ++------ prover/src/logup_gkr/prover.rs | 19 ++++++++++++------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/prover/Cargo.toml b/prover/Cargo.toml index f75acdc8e..125ab2a4f 100644 --- a/prover/Cargo.toml +++ b/prover/Cargo.toml @@ -15,20 +15,16 @@ rust-version = "1.78" [lib] bench = false -[[bench]] -name = "logup_gkr" -harness = false - [[bench]] name = "logup_gkr_e2e" harness = false [[bench]] -name = "row_matrix" +name = "logup_gkr" harness = false [[bench]] -name = "lagrange_kernel" +name = "row_matrix" harness = false [features] diff --git a/prover/src/logup_gkr/prover.rs b/prover/src/logup_gkr/prover.rs index 899482f22..0da61fbd8 100644 --- a/prover/src/logup_gkr/prover.rs +++ b/prover/src/logup_gkr/prover.rs @@ -8,9 +8,9 @@ use sumcheck::{ EqFunction, FinalLayerProof, GkrCircuitProof, MultiLinearPoly, SumCheckProof, }; use tracing::instrument; -use utils::iter; #[cfg(feature = "concurrent")] -pub use utils::rayon::prelude::*; +use utils::rayon::prelude::*; +use utils::{iter, iter_mut, uninit_vector}; use super::{reduce_layer_claim, CircuitLayerPolys, EvaluatedCircuit, GkrClaim, GkrProverError}; use crate::{matrix::ColMatrix, Trace}; @@ -81,7 +81,7 @@ pub fn prove_gkr( // build the MLEs of the relevant main trace columns let main_trace_mls = - build_mls_from_main_trace_segment(evaluator.get_oracles(), main_trace.main_segment())?; + build_mle_from_main_trace_segment(evaluator.get_oracles(), main_trace.main_segment())?; // build the periodic table representing periodic columns as multi-linear extensions let periodic_table = evaluator.build_periodic_values(); @@ -140,11 +140,11 @@ fn prove_input_layer< /// Builds the multi-linear extension polynomials needed to run the final sum-check of GKR for /// LogUp-GKR. #[instrument(skip_all)] -fn build_mls_from_main_trace_segment( +fn build_mle_from_main_trace_segment( oracles: &[LogUpGkrOracle], main_trace: &ColMatrix<::BaseField>, ) -> Result>, GkrProverError> { - let mut mls = vec![]; + let mut mls = Vec::with_capacity(oracles.len()); for oracle in oracles { match oracle { @@ -156,13 +156,18 @@ fn build_mls_from_main_trace_segment( }, LogUpGkrOracle::NextRow(index) => { let col = main_trace.get_column(*index); - let mut values: Vec = col.iter().map(|value| E::from(*value)).collect(); - values.rotate_left(1); + + let mut values: Vec = unsafe { uninit_vector(col.len()) }; + values[col.len() - 1] = E::from(col[0]); + iter_mut!(&mut values[..col.len() - 1]) + .enumerate() + .for_each(|(i, value)| *value = E::from(col[i + 1])); let ml = MultiLinearPoly::from_evaluations(values); mls.push(ml) }, }; } + Ok(mls) } From d34e0b1f48414bebc32e40ff7f3d0bd6a93faea7 Mon Sep 17 00:00:00 2001 From: Al-Kindi-0 <82364884+Al-Kindi-0@users.noreply.github.com> Date: Tue, 1 Oct 2024 19:51:34 +0200 Subject: [PATCH 19/19] Reduce degree of sum-check round polynomials (#328) --- prover/src/logup_gkr/prover.rs | 6 +- sumcheck/benches/sum_check_plain.rs | 17 +- sumcheck/src/lib.rs | 2 +- sumcheck/src/prover/high_degree.rs | 209 ++++++++++++++++++----- sumcheck/src/prover/mod.rs | 67 ++++++++ sumcheck/src/prover/plain.rs | 248 +++++++++++++++++----------- sumcheck/src/univariate.rs | 113 ++++++------- 7 files changed, 448 insertions(+), 214 deletions(-) diff --git a/prover/src/logup_gkr/prover.rs b/prover/src/logup_gkr/prover.rs index 0da61fbd8..6f413f201 100644 --- a/prover/src/logup_gkr/prover.rs +++ b/prover/src/logup_gkr/prover.rs @@ -205,13 +205,14 @@ fn prove_intermediate_layers< // reduced in terms of the input layer separately in `prove_final_circuit_layer`. for inner_layer in circuit.layers().into_iter().skip(1).rev().skip(1) { // construct the Lagrange kernel evaluated at the previous GKR round randomness - let mut eq_mle = EqFunction::ml_at(evaluation_point.into()); + let mut eq_mle = EqFunction::ml_at(evaluation_point.clone().into()); let (numerators, denominators) = inner_layer.into_numerators_denominators(); // run the sumcheck protocol let proof = sum_check_prove_num_rounds_degree_3( claimed_evaluation, + &evaluation_point, numerators, denominators, &mut eq_mle, @@ -260,6 +261,7 @@ fn sum_check_prove_num_rounds_degree_3< H: ElementHasher, >( claim: (E, E), + evaluation_point: &[E], p: MultiLinearPoly, q: MultiLinearPoly, eq: &mut MultiLinearPoly, @@ -270,7 +272,7 @@ fn sum_check_prove_num_rounds_degree_3< let r_batch = transcript.draw().map_err(|_| GkrProverError::FailedToGenerateChallenge)?; let claim = claim.0 + claim.1 * r_batch; - let proof = sumcheck_prove_plain(claim, r_batch, p, q, eq, transcript)?; + let proof = sumcheck_prove_plain(claim, evaluation_point, r_batch, p, q, eq, transcript)?; Ok(proof) } diff --git a/sumcheck/benches/sum_check_plain.rs b/sumcheck/benches/sum_check_plain.rs index 6e15603ef..203961fa4 100644 --- a/sumcheck/benches/sum_check_plain.rs +++ b/sumcheck/benches/sum_check_plain.rs @@ -27,10 +27,18 @@ fn sum_check_plain(c: &mut Criterion) { DefaultRandomCoin::>::new(&[BaseElement::ZERO; 4]); (setup_sum_check::(log_poly_size), transcript) }, - |((claim, r_batch, p, q, eq), transcript)| { + |((claim, evaluation_point, r_batch, p, q, eq), transcript)| { let mut eq = eq; let mut transcript = transcript; - sumcheck_prove_plain(claim, r_batch, p, q, &mut eq, &mut transcript) + sumcheck_prove_plain( + claim, + &evaluation_point, + r_batch, + p, + q, + &mut eq, + &mut transcript, + ) }, BatchSize::SmallInput, ) @@ -42,7 +50,7 @@ fn sum_check_plain(c: &mut Criterion) { #[allow(clippy::type_complexity)] fn setup_sum_check( log_size: usize, -) -> (E, E, MultiLinearPoly, MultiLinearPoly, MultiLinearPoly) { +) -> (E, Vec, E, MultiLinearPoly, MultiLinearPoly, MultiLinearPoly) { let n = 1 << (log_size + 1); let p: Vec = rand_vector(n); let q: Vec = rand_vector(n); @@ -52,12 +60,13 @@ fn setup_sum_check( let rand_pt = rand_vector(log_size); let r_batch: E = rand_value(); let claim: E = rand_value(); + let evaluation_point = rand_vector(log_size); let p = MultiLinearPoly::from_evaluations(p); let q = MultiLinearPoly::from_evaluations(q); let eq = MultiLinearPoly::from_evaluations(EqFunction::new(rand_pt.into()).evaluations()); - (claim, r_batch, p, q, eq) + (claim, evaluation_point, r_batch, p, q, eq) } criterion_group!(group, sum_check_plain); diff --git a/sumcheck/src/lib.rs b/sumcheck/src/lib.rs index b7f670a9d..b11f19d74 100644 --- a/sumcheck/src/lib.rs +++ b/sumcheck/src/lib.rs @@ -23,7 +23,7 @@ mod verifier; pub use verifier::*; mod univariate; -pub use univariate::{CompressedUnivariatePoly, CompressedUnivariatePolyEvals}; +pub use univariate::CompressedUnivariatePoly; mod multilinear; pub use multilinear::{inner_product, EqFunction, MultiLinearPoly}; diff --git a/sumcheck/src/prover/high_degree.rs b/sumcheck/src/prover/high_degree.rs index 0c5c0aff7..da2021f63 100644 --- a/sumcheck/src/prover/high_degree.rs +++ b/sumcheck/src/prover/high_degree.rs @@ -11,10 +11,10 @@ use math::FieldElement; #[cfg(feature = "concurrent")] pub use rayon::prelude::*; -use super::SumCheckProverError; +use super::{compute_scaling_down_factors, to_coefficients, SumCheckProverError}; use crate::{ - evaluate_composition_poly, CompressedUnivariatePolyEvals, EqFunction, FinalOpeningClaim, - MultiLinearPoly, RoundProof, SumCheckProof, SumCheckRoundClaim, + evaluate_composition_poly, EqFunction, FinalOpeningClaim, MultiLinearPoly, RoundProof, + SumCheckProof, SumCheckRoundClaim, }; /// A sum-check prover for the input layer which can accommodate non-linear expressions in @@ -148,7 +148,102 @@ use crate::{ /// \right) /// $$ /// +/// +/// We now discuss a further optimization due to [2]. Suppose that we have a sum-check statment of +/// the following form: +/// +/// $$v_0=\sum_{x}Eq\left(\left(\alpha_0,\cdots,\alpha_{\nu - 1}\right);\left( x_0, \cdots, x_{\nu - 1}\right)\right) +/// C\left( x_0, \cdots, x_{\nu - 1} \right)$$ +/// +/// Then during round $i + 1$ of sum-check, the prover needs to send the following polynomial +/// +/// $$v_{i+1}(X)=\sum_{x}Eq\left(\left(\alpha_0,\cdots,\alpha_{i - 1},\alpha_i, \alpha_{i+1},\cdots\alpha_{\nu - 1} \right); +/// \left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// We can write $v_{i+1}(X)$ as: +/// +/// $$v_{i+1}(X)=Eq\left(\left(\alpha_0,\cdots,\alpha_{i - 1} \right);\left(r_0,\cdots,r_{i-1}\right)\right) +/// \cdot Eq\left(\alpha_i ;X\right)\sum_{x}Eq\left(\left(\alpha_{i+1},\cdots\alpha_{\nu - 1}\right);\left( x_{i+1}, \cdots x_{\nu - 1}\right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// This means that $v_{i+1}(X)$ is the product of: +/// +/// 1. A constant polynomial: $Eq\left( \left(\alpha_0, \cdots, \alpha_{i - 1} \right);\left( r_0, \cdots, r_{i-1} \right) \right)$ +/// 2. A linear polynomial: $Eq\left( \alpha_i ; X \right)$ +/// 3. A high degree polynomial: $\sum_{x} +/// Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right);\left( x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$ +/// +/// The advantage of the above decomposition is that the prover when computing $v_{i+1}(X)$ needs to sum over +/// +/// $$ +/// Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right);\left( x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right) +/// $$ +/// +/// instead of +/// +/// $$ +/// Eq\left( \left(\alpha_0, \cdots, \alpha_{i - 1}, \alpha_i, \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right) +/// $$ +/// +/// which has the advantage of being of degree $1$ less and hence requires less work on the part of the prover. +/// +/// Thus, the prover computes the following polynomial +/// +/// $$v_{i+1}^{'}(X) = \sum_{x} Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left( x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// and then scales it in order to get +/// +/// $$ +/// v_{i+1}(X) = v_{i+1}^{'}(X) Eq\left( \left(\alpha_0, \cdots, \alpha_{i - 1} \right); +/// \left( r_0, \cdots, r_{i-1} \right) \right) \cdot Eq\left( \alpha_i ; X \right) +/// $$ +/// +/// As the prover computes $v_{i+1}^{'}(X)$ in evaluation form and hence also $v_{i+1}(X)$, this +/// means that due to the degrees being off by $1$, the prover uses the linear factor in order to +/// obtain an additional evaluation point in order to be able to interpolate $v_{i+1}(X)$. +/// More precisely, we can get a root of $$v_{i+1}(X) = 0$$ by solving $$Eq\left( \alpha_i ; X \right) = 0$$ +/// The latter equation has as solution $$\mathsf{r} = \frac{1 - \alpha}{1 - 2\cdot\alpha}$$ +/// which is, except with negligible probability, an evaluation point not in the original +/// evaluation set and hence the prover is able to interpolate $v_{i+1}(X)$ and send it to +/// the verifier. +/// +/// Note that in order to avoid having to compute $\{Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left( x_{i+1}, \cdots x_{\nu - 1} \right) \right)\}$ from $\{Eq\left( \left( \alpha_{i}, \cdots \alpha_{\nu - 1} \right); +/// \left( x_{i}, \cdots x_{\nu - 1} \right) \right)\}$, or vice versa, we can write +/// +/// $$v_{i+1}^{'}(X) = \sum_{x} Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left( x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// as +/// +/// $$v_{i+1}^{'}(X) = \frac{1}{Eq\left( \left( \alpha_{0}, \cdots, \alpha_{i} \right); +/// \left(0, \cdots, 0\right) \right)} \sum_{x} +/// Eq\left( \left( \alpha_{0}, \cdots, \alpha_{i}, \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left(0, \cdots, 0, x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// Thus, $\{Eq\left( \left( \alpha_{0}, \cdots, \alpha_{i}, \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left(0, \cdots, 0, x_{i+1}, \cdots x_{\nu - 1} \right) \right)\}$ can be read from +/// $\{Eq\left( \left( \alpha_{0}, \cdots, \alpha_{\nu - 1} \right);\left(x_{0}, \cdots x_{\nu - 1} \right) \right)\}$ +/// directly, at the cost of the relation between $v_{i+1}^{'}(X)$ and $v_{i+1}(X)$ becoming +/// +/// $$ +/// v_{i+1}(X) = v_{i+1}^{'}(X) \frac{Eq\left( \left(\alpha_0, \cdots, \alpha_{i - 1} \right); +/// \left( r_0, \cdots, r_{i-1} \right) \right)}{Eq\left( \left( \alpha_{0}, \cdots, \alpha_{i} \right); +/// \left(0, \cdots, 0\right) \right)} \cdot Eq\left( \alpha_i ; X \right) +/// $$ +/// +/// /// [1]: https://eprint.iacr.org/2023/1284 +/// [2]: https://eprint.iacr.org/2024/108 #[allow(clippy::too_many_arguments)] pub fn sum_check_prove_higher_degree< E: FieldElement, @@ -171,13 +266,14 @@ pub fn sum_check_prove_higher_degree< let mu = evaluator.get_num_fractions().trailing_zeros() - 1; let (evaluation_point_mu, evaluation_point_nu) = evaluation_point.split_at(mu as usize); let eq_mu = EqFunction::ml_at(evaluation_point_mu.into()).evaluations().to_vec(); - let mut eq_nu = EqFunction::ml_at(evaluation_point_nu.into()); + let eq_nu = EqFunction::ml_at(evaluation_point_nu.into()); // setup first round claim let mut current_round_claim = SumCheckRoundClaim { eval_point: vec![], claim }; // run the first round of the protocol - let round_poly_evals = sumcheck_round( + let mut round_poly_evals = sumcheck_round( + 0, &eq_mu, evaluator, &eq_nu, @@ -186,7 +282,21 @@ pub fn sum_check_prove_higher_degree< &log_up_randomness, r_sum_check, ); - let round_poly_coefs = round_poly_evals.to_poly(current_round_claim.claim); + + // this will hold `Eq((\alpha_0, \cdots, \alpha_{i - 1});(r_0, \cdots, r_{i-1}))` + let mut scaling_up_factor = E::ONE; + // this will hold `Eq((\alpha_{0}, \cdots, \alpha_{i}); (0, \cdots, 0))` for all `i` + let scaling_down_factors = compute_scaling_down_factors(evaluation_point_nu); + // this is `\alpha_i` above + let mut alpha_i = evaluation_point_nu[0]; + let scaling_down_factor = scaling_down_factors[0]; + let round_poly_coefs = to_coefficients( + &mut round_poly_evals, + current_round_claim.claim, + alpha_i, + scaling_down_factor, + scaling_up_factor, + ); // reseed with the s_0 polynomial coin.reseed(H::hash_elements(&round_poly_coefs.0)); @@ -197,6 +307,11 @@ pub fn sum_check_prove_higher_degree< let round_challenge = coin.draw().map_err(|_| SumCheckProverError::FailedToGenerateChallenge)?; + // update `scaling_up_factor` + alpha_i = evaluation_point_nu[evaluation_point_nu.len() - mls[0].num_variables()]; + scaling_up_factor *= + round_challenge * alpha_i + (E::ONE - round_challenge) * (E::ONE - alpha_i); + // compute the new reduced round claim let new_round_claim = reduce_claim(&round_proofs[i - 1], current_round_claim, round_challenge); @@ -204,14 +319,14 @@ pub fn sum_check_prove_higher_degree< // fold each multi-linear using the round challenge mls.iter_mut() .for_each(|ml| ml.bind_least_significant_variable(round_challenge)); - eq_nu.bind_least_significant_variable(round_challenge); // fold each periodic multi-linear using the round challenge periodic_table.bind_least_significant_variable(round_challenge); // run the i-th round of the protocol using the folded multi-linears for the new reduced // claim. This basically computes the s_i polynomial. - let round_poly_evals = sumcheck_round( + let mut round_poly_evals = sumcheck_round( + i, &eq_mu, evaluator, &eq_nu, @@ -224,7 +339,14 @@ pub fn sum_check_prove_higher_degree< // update the claim current_round_claim = new_round_claim; - let round_poly_coefs = round_poly_evals.to_poly(current_round_claim.claim); + let alpha_i = evaluation_point_nu[i]; + let round_poly_coefs = to_coefficients( + &mut round_poly_evals, + current_round_claim.claim, + alpha_i, + scaling_down_factors[i], + scaling_up_factor, + ); // reseed with the s_i polynomial coin.reseed(H::hash_elements(&round_poly_coefs.0)); @@ -239,7 +361,6 @@ pub fn sum_check_prove_higher_degree< // fold each multi-linear using the last random round challenge mls.iter_mut() .for_each(|ml| ml.bind_least_significant_variable(round_challenge)); - eq_nu.bind_least_significant_variable(round_challenge); let SumCheckRoundClaim { eval_point, claim: _claim } = reduce_claim(&round_proofs[num_rounds - 1], current_round_claim, round_challenge); @@ -293,7 +414,9 @@ pub fn sum_check_prove_higher_degree< /// the previous one using only additions. This is the purpose of `deltas`, to hold the increments /// added to each multi-linear to compute the evaluation at the next point, and `evals_x` to hold /// the current evaluation at $x$ in $\{2, ... , d_max\}$. +#[allow(clippy::too_many_arguments)] fn sumcheck_round( + sum_check_round: usize, eq_mu: &[E], evaluator: &impl LogUpGkrEvaluator::BaseField>, eq_ml: &MultiLinearPoly, @@ -301,7 +424,7 @@ fn sumcheck_round( periodic_table: &PeriodicTable, log_up_randomness: &[E], r_sum_check: E, -) -> CompressedUnivariatePolyEvals { +) -> Vec { let num_mls = mls.len(); let num_periodic = periodic_table.num_columns(); let num_vars = mls[0].num_variables(); @@ -316,47 +439,47 @@ fn sumcheck_round( let mut evals_periodic_one = vec![E::ZERO; num_periodic]; let mut evals_periodic_zero = vec![E::ZERO; num_periodic]; let mut evals_periodic_x = vec![E::ZERO; num_periodic]; - let mut eq_x = E::ZERO; let mut deltas = vec![E::ZERO; num_mls]; let mut deltas_periodic = vec![E::ZERO; num_periodic]; - let mut eq_delta = E::ZERO; let mut numerators = vec![E::ZERO; evaluator.get_num_fractions()]; let mut denominators = vec![E::ZERO; evaluator.get_num_fractions()]; (0..1 << num_rounds) .map(|i| { - let mut total_evals = vec![E::ZERO; evaluator.max_degree()]; + let mut poly_evals = vec![E::ZERO; evaluator.max_degree() - 1]; for (j, ml) in mls.iter().enumerate() { evals_zero[j] = ml.evaluations()[2 * i]; evals_one[j] = ml.evaluations()[2 * i + 1]; } - let eq_at_zero = eq_ml.evaluations()[2 * i]; - let eq_at_one = eq_ml.evaluations()[2 * i + 1]; - // add evaluation of periodic columns periodic_table.fill_periodic_values_at(2 * i, &mut evals_periodic_zero); periodic_table.fill_periodic_values_at(2 * i + 1, &mut evals_periodic_one); - // compute the evaluation at 1 + // `(0, \cdots, 0, x_{i+1}, \cdots x_{\nu - 1})` + let j = i << (sum_check_round + 1); + // `Eq((\alpha_{0}, \cdots, \alpha_{\nu - 1}); (0, \cdots, 0, x_{i+1}, \cdots x_{\nu - 1})) ` + let eq_at_zero = eq_ml.evaluations()[j]; + + // compute the evaluation at 0 evaluator.evaluate_query( - &evals_one, - &evals_periodic_one, + &evals_zero, + &evals_periodic_zero, log_up_randomness, &mut numerators, &mut denominators, ); - total_evals[0] = evaluate_composition_poly( + poly_evals[0] = evaluate_composition_poly( eq_mu, &numerators, &denominators, - eq_at_one, + eq_at_zero, r_sum_check, ); - // compute the evaluations at 2, ..., d_max points + // compute the evaluations at `2, ..., d_max - 1` points for i in 0..num_mls { deltas[i] = evals_one[i] - evals_zero[i]; evals_x[i] = evals_one[i]; @@ -365,10 +488,8 @@ fn sumcheck_round( deltas_periodic[i] = evals_periodic_one[i] - evals_periodic_zero[i]; evals_periodic_x[i] = evals_periodic_one[i]; } - eq_delta = eq_at_one - eq_at_zero; - eq_x = eq_at_one; - for e in total_evals.iter_mut().skip(1) { + for e in poly_evals.iter_mut().skip(1) { evals_x.iter_mut().zip(deltas.iter()).for_each(|(evx, delta)| { *evx += *delta; }); @@ -377,7 +498,6 @@ fn sumcheck_round( *evx += *delta; }, ); - eq_x += eq_delta; evaluator.evaluate_query( &evals_x, @@ -390,14 +510,14 @@ fn sumcheck_round( eq_mu, &numerators, &denominators, - eq_x, + eq_at_zero, r_sum_check, ); } - total_evals + poly_evals }) - .fold(vec![E::ZERO; evaluator.max_degree()], |mut acc, poly_eval| { + .fold(vec![E::ZERO; evaluator.max_degree() - 1], |mut acc, poly_eval| { acc.iter_mut().zip(poly_eval.iter()).for_each(|(a, b)| { *a += *b; }); @@ -421,7 +541,7 @@ fn sumcheck_round( vec![E::ZERO; evaluator.get_num_fractions()], vec![E::ZERO; num_mls], vec![E::ZERO; num_periodic], - vec![E::ZERO; evaluator.max_degree()], + vec![E::ZERO; evaluator.max_degree() - 1], ) }, |( @@ -443,17 +563,19 @@ fn sumcheck_round( evals_one[j] = ml.evaluations()[2 * i + 1]; } - let eq_at_zero = eq_ml.evaluations()[2 * i]; - let eq_at_one = eq_ml.evaluations()[2 * i + 1]; - // add evaluation of periodic columns periodic_table.fill_periodic_values_at(2 * i, &mut evals_periodic_zero); periodic_table.fill_periodic_values_at(2 * i + 1, &mut evals_periodic_one); - // compute the evaluation at 1 + // `(0, \cdots, 0, x_{i+1}, \cdots x_{\nu - 1})` + let j = i << (sum_check_round + 1); + // `Eq((\alpha_{0}, \cdots, \alpha_{\nu - 1}); (0, \cdots, 0, x_{i+1}, \cdots x_{\nu - 1})) ` + let eq_at_zero = eq_ml.evaluations()[j]; + + // compute the evaluation at 0 evaluator.evaluate_query( - &evals_one, - &evals_periodic_one, + &evals_zero, + &evals_periodic_zero, log_up_randomness, &mut numerators, &mut denominators, @@ -462,11 +584,11 @@ fn sumcheck_round( eq_mu, &numerators, &denominators, - eq_at_one, + eq_at_zero, r_sum_check, ); - // compute the evaluations at 2, ..., d_max points + // compute the evaluations at `2, ..., d_max - 1` points for i in 0..num_mls { deltas[i] = evals_one[i] - evals_zero[i]; evals_x[i] = evals_one[i]; @@ -475,8 +597,6 @@ fn sumcheck_round( deltas_periodic[i] = evals_periodic_one[i] - evals_periodic_zero[i]; evals_periodic_x[i] = evals_periodic_one[i]; } - let eq_delta = eq_at_one - eq_at_zero; - let mut eq_x = eq_at_one; for e in poly_evals.iter_mut().skip(1) { evals_x.iter_mut().zip(deltas.iter()).for_each(|(evx, delta)| { @@ -487,7 +607,6 @@ fn sumcheck_round( *evx += *delta; }, ); - eq_x += eq_delta; evaluator.evaluate_query( &evals_x, @@ -500,7 +619,7 @@ fn sumcheck_round( eq_mu, &numerators, &denominators, - eq_x, + eq_at_zero, r_sum_check, ); } @@ -522,7 +641,7 @@ fn sumcheck_round( ) .map(|(.., poly_evals)| poly_evals) .reduce( - || vec![E::ZERO; evaluator.max_degree()], + || vec![E::ZERO; evaluator.max_degree() - 1], |mut acc, poly_eval| { acc.iter_mut().zip(poly_eval.iter()).for_each(|(a, b)| { *a += *b; @@ -531,7 +650,7 @@ fn sumcheck_round( }, ); - CompressedUnivariatePolyEvals(evaluations.into()) + evaluations } /// Reduces an old claim to a new claim using the round challenge. diff --git a/sumcheck/src/prover/mod.rs b/sumcheck/src/prover/mod.rs index 13d35e551..66a110ee6 100644 --- a/sumcheck/src/prover/mod.rs +++ b/sumcheck/src/prover/mod.rs @@ -4,10 +4,77 @@ // LICENSE file in the root directory of this source tree. mod high_degree; +use alloc::vec::Vec; + pub use high_degree::sum_check_prove_higher_degree; mod plain; +use math::{batch_inversion, FieldElement}; pub use plain::sumcheck_prove_plain; mod error; pub use error::SumCheckProverError; + +use crate::CompressedUnivariatePoly; + +/// Takes the evaluation of the polynomial $v_{i+1}^{'}(X)$ defined by +/// +/// $$v_{i+1}^{'}(X) = \sum_{x} Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left( x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// and computes the interpolation of the $v_{i+1}(X)$ polynomial defined by +/// +/// $$ +/// v_{i+1}(X) = v_{i+1}^{'}(X) \frac{Eq\left( \left(\alpha_0, \cdots, \alpha_{i - 1} \right); +/// \left( r_0, \cdots, r_{i-1} \right) \right)}{Eq\left( \left( \alpha_{0}, \cdots, \alpha_{i} \right); +/// \left(0, \cdots, 0\right) \right)} \cdot Eq\left( \alpha_i ; X \right) +/// $$ +/// +/// The function returns a `CompressedUnivariatePoly` instead of the full list of coefficients. +fn to_coefficients( + round_poly_evals: &mut [E], + claim: E, + alpha: E, + scaling_down_factor: E, + scaling_up_factor: E, +) -> CompressedUnivariatePoly { + let a = scaling_down_factor; + round_poly_evals.iter_mut().for_each(|e| *e *= scaling_up_factor); + + let mut round_poly_evaluations = Vec::with_capacity(round_poly_evals.len() + 1); + round_poly_evaluations.push(round_poly_evals[0] * compute_weight(alpha, E::ZERO) * a); + round_poly_evaluations.push(claim - round_poly_evaluations[0]); + + for (x, eval) in round_poly_evals.iter().skip(1).enumerate() { + round_poly_evaluations.push(*eval * compute_weight(alpha, E::from(x as u32 + 2)) * a) + } + + let root = (E::ONE - alpha) / (E::ONE - alpha.double()); + + CompressedUnivariatePoly::interpolate_equidistant_points(&round_poly_evaluations, root) +} + +/// Computes +/// +/// $$ +/// Eq\left( \left( \alpha_{0}, \cdots, \alpha_{i} \right); +/// \left(0, \cdots, 0\right) \right) +/// $$ +/// +/// given $(\alpha_0, \cdots, \alpha_{\nu - 1})$ for all $i$ in $0, \cdots, \nu - 1$. +fn compute_scaling_down_factors(gkr_point: &[E]) -> Vec { + let cumulative_product: Vec = gkr_point + .iter() + .scan(E::ONE, |acc, &x| { + *acc *= E::ONE - x; + Some(*acc) + }) + .collect(); + batch_inversion(&cumulative_product) +} + +/// Computes $EQ(x; \alpha)$. +fn compute_weight(alpha: E, x: E) -> E { + x * alpha + (E::ONE - x) * (E::ONE - alpha) +} diff --git a/sumcheck/src/prover/plain.rs b/sumcheck/src/prover/plain.rs index e0092cf10..8e4766b6a 100644 --- a/sumcheck/src/prover/plain.rs +++ b/sumcheck/src/prover/plain.rs @@ -7,13 +7,9 @@ use crypto::{ElementHasher, RandomCoin}; use math::FieldElement; #[cfg(feature = "concurrent")] pub use rayon::prelude::*; -use smallvec::smallvec; -use super::SumCheckProverError; -use crate::{ - comb_func, CompressedUnivariatePolyEvals, FinalOpeningClaim, MultiLinearPoly, RoundProof, - SumCheckProof, -}; +use super::{compute_scaling_down_factors, to_coefficients, SumCheckProverError}; +use crate::{comb_func, FinalOpeningClaim, MultiLinearPoly, RoundProof, SumCheckProof}; /// Sum-check prover for non-linear multivariate polynomial of the simple LogUp-GKR. /// @@ -46,10 +42,106 @@ use crate::{ /// /// Note that the degree of the non-linear composition polynomial is 3. /// +/// +/// We now discuss a further optimization due to [2]. Suppose that we have a sum-check statment of +/// the following form: +/// +/// $$v_0=\sum_{x}Eq\left(\left(\alpha_0,\cdots,\alpha_{\nu - 1}\right);\left( x_0, \cdots, x_{\nu - 1}\right)\right) +/// C\left( x_0, \cdots, x_{\nu - 1} \right)$$ +/// +/// Then during round $i + 1$ of sum-check, the prover needs to send the following polynomial +/// +/// $$v_{i+1}(X)=\sum_{x}Eq\left(\left(\alpha_0,\cdots,\alpha_{i - 1},\alpha_i, \alpha_{i+1},\cdots\alpha_{\nu - 1} \right); +/// \left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// We can write $v_{i+1}(X)$ as: +/// +/// $$v_{i+1}(X)=Eq\left(\left(\alpha_0,\cdots,\alpha_{i - 1} \right);\left(r_0,\cdots,r_{i-1}\right)\right) +/// \cdot Eq\left(\alpha_i ;X\right)\sum_{x}Eq\left(\left(\alpha_{i+1},\cdots\alpha_{\nu - 1}\right);\left( x_{i+1}, \cdots x_{\nu - 1}\right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// This means that $v_{i+1}(X)$ is the product of: +/// +/// 1. A constant polynomial: $Eq\left( \left(\alpha_0, \cdots, \alpha_{i - 1} \right);\left( r_0, \cdots, r_{i-1} \right) \right)$ +/// 2. A linear polynomial: $Eq\left( \alpha_i ; X \right)$ +/// 3. A high degree polynomial: $\sum_{x} +/// Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right);\left( x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$ +/// +/// The advantage of the above decomposition is that the prover when computing $v_{i+1}(X)$ needs to sum over +/// +/// $$ +/// Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right);\left( x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right) +/// $$ +/// +/// instead of +/// +/// $$ +/// Eq\left( \left(\alpha_0, \cdots, \alpha_{i - 1}, \alpha_i, \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right) +/// $$ +/// +/// which has the advantage of being of degree $1$ less and hence requires less work on the part of the prover. +/// +/// Thus, the prover computes the following polynomial +/// +/// $$v_{i+1}^{'}(X) = \sum_{x} Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left( x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// and then scales it in order to get +/// +/// $$ +/// v_{i+1}(X) = v_{i+1}^{'}(X) Eq\left( \left(\alpha_0, \cdots, \alpha_{i - 1} \right); +/// \left( r_0, \cdots, r_{i-1} \right) \right) \cdot Eq\left( \alpha_i ; X \right) +/// $$ +/// +/// As the prover computes $v_{i+1}^{'}(X)$ in evaluation form and hence also $v_{i+1}(X)$, this +/// means that due to the degrees being off by $1$, the prover uses the linear factor in order to +/// obtain an additional evaluation point in order to be able to interpolate $v_{i+1}(X)$. +/// More precisely, we can get a root of $$v_{i+1}(X) = 0$$ by solving $$Eq\left( \alpha_i ; X \right) = 0$$ +/// The latter equation has as solution $$\mathsf{r} = \frac{1 - \alpha}{1 - 2\cdot\alpha}$$ +/// which is, except with negligible probability, an evaluation point not in the original +/// evaluation set and hence the prover is able to interpolate $v_{i+1}(X)$ and send it to +/// the verifier. +/// +/// Note that in order to avoid having to compute $\{Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left( x_{i+1}, \cdots x_{\nu - 1} \right) \right)\}$ from $\{Eq\left( \left( \alpha_{i}, \cdots \alpha_{\nu - 1} \right); +/// \left( x_{i}, \cdots x_{\nu - 1} \right) \right)\}$, or vice versa, we can write +/// +/// $$v_{i+1}^{'}(X) = \sum_{x} Eq\left( \left( \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left( x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// as +/// +/// $$v_{i+1}^{'}(X) = \frac{1}{Eq\left( \left( \alpha_{0}, \cdots, \alpha_{i} \right); +/// \left(0, \cdots, 0\right) \right)} \sum_{x} +/// Eq\left( \left( \alpha_{0}, \cdots, \alpha_{i}, \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left(0, \cdots, 0, x_{i+1}, \cdots x_{\nu - 1} \right) \right) +/// C\left( r_0, \cdots, r_{i-1}, X, x_{i+1}, \cdots x_{\nu - 1} \right)$$ +/// +/// Thus, $\{Eq\left( \left( \alpha_{0}, \cdots, \alpha_{i}, \alpha_{i+1}, \cdots \alpha_{\nu - 1} \right); +/// \left(0, \cdots, 0, x_{i+1}, \cdots x_{\nu - 1} \right) \right)\}$ can be read from +/// $\{Eq\left( \left( \alpha_{0}, \cdots, \alpha_{\nu - 1} \right);\left(x_{0}, \cdots x_{\nu - 1} \right) \right)\}$ +/// directly, at the cost of the relation between $v_{i+1}^{'}(X)$ and $v_{i+1}(X)$ becoming +/// +/// $$ +/// v_{i+1}(X) = v_{i+1}^{'}(X) \frac{Eq\left( \left(\alpha_0, \cdots, \alpha_{i - 1} \right); +/// \left( r_0, \cdots, r_{i-1} \right) \right)}{Eq\left( \left( \alpha_{0}, \cdots, \alpha_{i} \right); +/// \left(0, \cdots, 0\right) \right)} \cdot Eq\left( \alpha_i ; X \right) +/// $$ +/// +/// /// [1]: https://eprint.iacr.org/2023/1284 +/// [2]: https://eprint.iacr.org/2024/108 #[allow(clippy::too_many_arguments)] pub fn sumcheck_prove_plain>( mut claim: E, + gkr_point: &[E], r_batch: E, p: MultiLinearPoly, q: MultiLinearPoly, @@ -64,124 +156,83 @@ pub fn sumcheck_prove_plain CompressedUnivariatePoly { // evaluate polynom::eval(&complete_coefficients, *challenge) } -} - -impl Serializable for CompressedUnivariatePoly { - fn write_into(&self, target: &mut W) { - let vector: Vec = self.0.clone().into_vec(); - vector.write_into(target); - } -} -impl Deserializable for CompressedUnivariatePoly -where - E: FieldElement, -{ - fn read_from(source: &mut R) -> Result { - let vector: Vec = Vec::::read_from(source)?; - Ok(Self(vector.into())) - } -} + /// Given the evaluations of a polynomial over the set $0, 1, \cdots, d - 1$ and a `root` not in + /// the interpolation set, computes its coefficients. + pub fn interpolate_equidistant_points(ys: &[E], root: E) -> CompressedUnivariatePoly { + // we factor out the term `(x - r)` where `r` is the root + let quotient: Vec = (0..ys.len()).map(|i| E::from(i as u32) - root).collect(); + let quotient_inv = batch_inversion("ient); + let mut ys: Vec = ys.iter().zip(quotient_inv.iter()).map(|(&y, &q)| y * q).collect(); -/// The evaluations of a univariate polynomial of degree n at 0, 1, ..., n with the evaluation at 0 -/// omitted. -/// -/// This compressed representation is useful during the sum-check protocol as the full uncompressed -/// representation can be recovered from the compressed one and the current sum-check round claim. -#[derive(Clone, Debug)] -pub struct CompressedUnivariatePolyEvals(pub(crate) SmallVec<[E; MAX_POLY_SIZE]>); + // the zeroth coefficient can be recovered immediately + let c0 = ys.remove(0); -impl CompressedUnivariatePolyEvals { - /// Gives the coefficient representation of a polynomial represented in evaluation form. - /// - /// Since the evaluation at 0 is omitted, we need to use the round claim to recover - /// the evaluation at 0 using the identity $p(0) + p(1) = claim$. - /// Now, we have that for any polynomial $p(x) = c0 + c1 * x + ... + c_{n-1} * x^{n - 1}$: - /// - /// 1. $p(0) = c0$. - /// 2. $p(x) = c0 + x * q(x) where q(x) = c1 + ... + c_{n-1} * x^{n - 2}$. - /// - /// This means that we can compute the evaluations of q at 1, ..., n - 1 using the evaluations - /// of p and thus reduce by 1 the size of the interpolation problem. - /// Once the coefficient of q are recovered, the c0 coefficient is appended to these and this - /// is precisely the coefficient representation of the original polynomial q. - /// Note that the coefficient of the linear term is removed as this coefficient can be recovered - /// from the remaining coefficients, again, using the round claim using the relation - /// $2 * c0 + c1 + ... c_{n - 1} = claim$. - pub fn to_poly(&self, round_claim: E) -> CompressedUnivariatePoly { - // construct the vector of interpolation points 1, ..., n - let n_minus_1 = self.0.len(); + // build the interpolation set + let n_minus_1 = ys.len(); let points = (1..=n_minus_1 as u32).map(E::BaseField::from).collect::>(); // construct their inverses. These will be needed for computing the evaluations - // of the q polynomial as well as for doing the interpolation on q + // of the q polynomial as well as for doing the interpolation on q where q is + // defined as $p(x) = c0 + x * q(x) where q(x) = c1 + ... + c_{n-1} * x^{n - 2}$ let points_inv = batch_inversion(&points); - // compute the zeroth coefficient - let c0 = round_claim - self.0[0]; - // compute the evaluations of q - let q_evals: Vec = self - .0 + let q_evals: Vec = ys .iter() .enumerate() .map(|(i, evals)| (*evals - c0).mul_base(points_inv[i])) @@ -118,11 +82,34 @@ impl CompressedUnivariatePolyEvals { // append c0 to the coefficients of q to get the coefficients of p. The linear term // coefficient is removed as this can be recovered from the other coefficients using // the reduced claim. - let mut coefficients = SmallVec::with_capacity(self.0.len() + 1); + let mut coefficients = SmallVec::<[E; MAX_POLY_SIZE]>::with_capacity(ys.len() + 1); coefficients.push(c0); - coefficients.extend_from_slice(&q_coefs[1..]); + coefficients.extend_from_slice(&q_coefs[..]); + + // multiply back the factor `(x - r)` + let mut p_coefficients = polynom::mul(&coefficients, &[-root, E::ONE]); + + // remove the linear factor as it can be recovered from the `claim` and the other factors + p_coefficients.remove(1); + + CompressedUnivariatePoly(p_coefficients.into()) + } +} - CompressedUnivariatePoly(coefficients) +impl Serializable for CompressedUnivariatePoly { + fn write_into(&self, target: &mut W) { + let vector: Vec = self.0.clone().into_vec(); + vector.write_into(target); + } +} + +impl Deserializable for CompressedUnivariatePoly +where + E: FieldElement, +{ + fn read_from(source: &mut R) -> Result { + let vector: Vec = Vec::::read_from(source)?; + Ok(Self(vector.into())) } } @@ -259,22 +246,18 @@ fn test_poly_partial() { use math::fields::f64::BaseElement; let degree = 1000; - let mut points: Vec = vec![BaseElement::ZERO; degree]; - points - .iter_mut() - .enumerate() - .for_each(|(i, node)| *node = BaseElement::from(i as u32)); + // compute the claim let p: Vec = rand_utils::rand_vector(degree); - let evals = polynom::eval_many(&p, &points); - - let mut partial_evals = evals.clone(); - partial_evals.remove(0); - - let partial_poly = CompressedUnivariatePolyEvals(partial_evals.into()); + let evals = polynom::eval_many(&p, &[BaseElement::ZERO, BaseElement::ONE]); let claim = evals[0] + evals[1]; - let poly_coeff = partial_poly.to_poly(claim); + // build compressed polynomial + let mut poly_coeff = p.clone(); + poly_coeff.remove(1); + let poly_coeff = CompressedUnivariatePoly(poly_coeff.into()); + + // generate random challenge let r = rand_utils::rand_vector(1); assert_eq!(polynom::eval(&p, r[0]), poly_coeff.evaluate_using_claim(&claim, &r[0]))