diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 963708fd4..bdbbef140 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -145,12 +145,16 @@ jobs: - name: Run bn254_curve_syscall (release) env: RUSTFLAGS: "-C opt-level=3" - run: cargo run --release --package ceno_zkvm --bin e2e -- --platform=ceno examples/target/riscv32im-ceno-zkvm-elf/release/examples/bn254_curve_syscalls + run: | + ulimit -s 65536 + cargo run --release --package ceno_zkvm --bin e2e -- --platform=ceno examples/target/riscv32im-ceno-zkvm-elf/release/examples/bn254_curve_syscalls - name: Run bn254_fptower_syscalls (release) env: RUSTFLAGS: "-C opt-level=3" - run: cargo run --release --package ceno_zkvm --bin e2e -- --platform=ceno examples/target/riscv32im-ceno-zkvm-elf/release/examples/bn254_fptower_syscalls + run: | + ulimit -s 65536 + cargo run --release --package ceno_zkvm --bin e2e -- --platform=ceno examples/target/riscv32im-ceno-zkvm-elf/release/examples/bn254_fptower_syscalls - name: Run k256 ecrecover (release) env: diff --git a/Cargo.lock b/Cargo.lock index 9317978be..acbb1604a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2237,7 +2237,7 @@ dependencies = [ [[package]] name = "ff_ext" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.25#f56aea652c0ba8a510bac04af9587ad500ed4a1c" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.28#eda8eab4c18dfc20d14b464dfe8484e8f0498347" dependencies = [ "once_cell", "p3", @@ -3243,7 +3243,7 @@ dependencies = [ [[package]] name = "mpcs" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.25#f56aea652c0ba8a510bac04af9587ad500ed4a1c" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.28#eda8eab4c18dfc20d14b464dfe8484e8f0498347" dependencies = [ "bincode 1.3.3", "clap", @@ -3267,7 +3267,7 @@ dependencies = [ [[package]] name = "multilinear_extensions" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.25#f56aea652c0ba8a510bac04af9587ad500ed4a1c" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.28#eda8eab4c18dfc20d14b464dfe8484e8f0498347" dependencies = [ "either", "ff_ext", @@ -4558,7 +4558,7 @@ dependencies = [ [[package]] name = "p3" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.25#f56aea652c0ba8a510bac04af9587ad500ed4a1c" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.28#eda8eab4c18dfc20d14b464dfe8484e8f0498347" dependencies = [ "p3-air", "p3-baby-bear", @@ -5126,7 +5126,7 @@ dependencies = [ [[package]] name = "poseidon" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.25#f56aea652c0ba8a510bac04af9587ad500ed4a1c" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.28#eda8eab4c18dfc20d14b464dfe8484e8f0498347" dependencies = [ "ff_ext", "p3", @@ -6083,7 +6083,7 @@ dependencies = [ [[package]] name = "sp1-curves" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.25#f56aea652c0ba8a510bac04af9587ad500ed4a1c" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.28#eda8eab4c18dfc20d14b464dfe8484e8f0498347" dependencies = [ "cfg-if", "dashu", @@ -6208,7 +6208,7 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.25#f56aea652c0ba8a510bac04af9587ad500ed4a1c" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.28#eda8eab4c18dfc20d14b464dfe8484e8f0498347" dependencies = [ "either", "ff_ext", @@ -6226,7 +6226,7 @@ dependencies = [ [[package]] name = "sumcheck_macro" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.25#f56aea652c0ba8a510bac04af9587ad500ed4a1c" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.28#eda8eab4c18dfc20d14b464dfe8484e8f0498347" dependencies = [ "itertools 0.13.0", "p3", @@ -6633,7 +6633,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.25#f56aea652c0ba8a510bac04af9587ad500ed4a1c" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.28#eda8eab4c18dfc20d14b464dfe8484e8f0498347" dependencies = [ "ff_ext", "itertools 0.13.0", @@ -6927,7 +6927,7 @@ dependencies = [ [[package]] name = "whir" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.25#f56aea652c0ba8a510bac04af9587ad500ed4a1c" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.28#eda8eab4c18dfc20d14b464dfe8484e8f0498347" dependencies = [ "bincode 1.3.3", "clap", @@ -7214,7 +7214,7 @@ dependencies = [ [[package]] name = "witness" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.25#f56aea652c0ba8a510bac04af9587ad500ed4a1c" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.28#eda8eab4c18dfc20d14b464dfe8484e8f0498347" dependencies = [ "ff_ext", "multilinear_extensions", diff --git a/Cargo.toml b/Cargo.toml index 00e4b0641..597f0208a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,16 +27,16 @@ version = "0.1.0" ceno_crypto_primitives = { git = "https://github.com/scroll-tech/ceno-patch.git", package = "ceno_crypto_primitives", branch = "main" } ceno_syscall = { git = "https://github.com/scroll-tech/ceno-patch.git", package = "ceno_syscall", branch = "main" } -ff_ext = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "ff_ext", tag = "v1.0.0-alpha.25" } -mpcs = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "mpcs", tag = "v1.0.0-alpha.25" } -multilinear_extensions = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "multilinear_extensions", tag = "v1.0.0-alpha.25" } -p3 = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "p3", tag = "v1.0.0-alpha.25" } -poseidon = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "poseidon", tag = "v1.0.0-alpha.25" } -sp1-curves = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sp1-curves", tag = "v1.0.0-alpha.25" } -sumcheck = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sumcheck", tag = "v1.0.0-alpha.25" } -transcript = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "transcript", tag = "v1.0.0-alpha.25" } -whir = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "whir", tag = "v1.0.0-alpha.25" } -witness = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "witness", tag = "v1.0.0-alpha.25" } +ff_ext = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "ff_ext", tag = "v1.0.0-alpha.28" } +mpcs = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "mpcs", tag = "v1.0.0-alpha.28" } +multilinear_extensions = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "multilinear_extensions", tag = "v1.0.0-alpha.28" } +p3 = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "p3", tag = "v1.0.0-alpha.28" } +poseidon = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "poseidon", tag = "v1.0.0-alpha.28" } +sp1-curves = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sp1-curves", tag = "v1.0.0-alpha.28" } +sumcheck = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sumcheck", tag = "v1.0.0-alpha.28" } +transcript = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "transcript", tag = "v1.0.0-alpha.28" } +whir = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "whir", tag = "v1.0.0-alpha.28" } +witness = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "witness", tag = "v1.0.0-alpha.28" } anyhow = { version = "1.0", default-features = false } bincode = "1" diff --git a/ceno_recursion/src/zkvm_verifier/binding.rs b/ceno_recursion/src/zkvm_verifier/binding.rs index 17d8f0bcd..094b812c5 100644 --- a/ceno_recursion/src/zkvm_verifier/binding.rs +++ b/ceno_recursion/src/zkvm_verifier/binding.rs @@ -74,6 +74,7 @@ pub struct ZKVMProofInputVariable { pub shard_id: Usize, pub pi: Array>, pub chip_proofs: Array>>, + pub main_constraint_proof: SumcheckLayerProofVariable, pub max_num_var: Var, pub max_width: Var, pub witin_commit: BasefoldCommitmentVariable, @@ -96,6 +97,7 @@ pub(crate) struct ZKVMProofInput { pub shard_id: usize, pub pi: Vec, pub chip_proofs: BTreeMap, + pub main_constraint_proof: SumcheckLayerProofInput, pub witin_commit: BasefoldCommitment, pub opening_proof: BasefoldProof, } @@ -143,18 +145,48 @@ impl ZKVMProofInput { let (num_witin, num_fixed) = *chip_witin_num_vars .get(&chip_idx) .expect("num_witin data should exist"); + let composed_cs = vk + .circuit_vks + .values() + .nth(chip_idx) + .expect("chip vk should exist") + .get_cs(); ( chip_idx, proofs .into_iter() .map(|proof| { - ZKVMChipProofInput::from((chip_idx, proof, num_witin, num_fixed)) + let sum_num_instances = proof.num_instances.iter().sum::(); + let mut num_vars = + ceil_log2(next_pow2_instance_padding(sum_num_instances)) + + composed_cs.rotation_vars().unwrap_or(0); + if composed_cs.has_ecc_ops() { + num_vars += 1; + } + ZKVMChipProofInput::from(( + chip_idx, proof, num_vars, num_witin, num_fixed, + )) }) .collect::>() .into(), ) }) .collect::>(), + main_constraint_proof: SumcheckLayerProofInput { + proof: IOPProverMessageVec::from( + zkvm_proof + .main_constraint_proof + .proof + .proof + .proofs + .iter() + .map(|p| IOPProverMessage { + evaluations: p.evaluations.clone(), + }) + .collect::>(), + ), + evals: zkvm_proof.main_constraint_proof.proof.evals, + }, witin_commit: zkvm_proof.witin_commit.into(), opening_proof: zkvm_proof.opening_proof.into(), } @@ -170,6 +202,9 @@ impl Hintable for ZKVMProofInput { builder.cycle_tracker_start("read chip proofs"); let chip_proofs = Vec::::read(builder); builder.cycle_tracker_end("read chip proofs"); + builder.cycle_tracker_start("read main constraint proof"); + let main_constraint_proof = SumcheckLayerProofInput::read(builder); + builder.cycle_tracker_end("read main constraint proof"); let max_num_var = usize::read(builder); let max_width = usize::read(builder); let witin_commit = BasefoldCommitment::read(builder); @@ -184,6 +219,7 @@ impl Hintable for ZKVMProofInput { shard_id, pi, chip_proofs, + main_constraint_proof, max_num_var, max_width, witin_commit, @@ -257,6 +293,7 @@ impl Hintable for ZKVMProofInput { for proofs in self.chip_proofs.values() { stream.extend(proofs.write()); } + stream.extend(self.main_constraint_proof.write()); stream.extend(>::write(&max_num_var)); stream.extend(>::write(&max_width)); stream.extend(self.witin_commit.write()); @@ -427,8 +464,8 @@ impl Hintable for ZKVMChipProofs { } } -impl From<(usize, ZKVMChipProof, usize, usize)> for ZKVMChipProofInput { - fn from(d: (usize, ZKVMChipProof, usize, usize)) -> Self { +impl From<(usize, ZKVMChipProof, usize, usize, usize)> for ZKVMChipProofInput { + fn from(d: (usize, ZKVMChipProof, usize, usize, usize)) -> Self { let idx = d.0; let p = d.1; @@ -443,14 +480,14 @@ impl From<(usize, ZKVMChipProof, usize, usize)> for ZKVMChipProofInput { .collect::>(); vars[0] } else { - 0 + d.2 }; Self { idx, num_vars, - num_witin: d.2, - num_fixed: d.3, + num_witin: d.3, + num_fixed: d.4, r_out_evals_len: p.r_out_evals.len(), w_out_evals_len: p.w_out_evals.len(), lk_out_evals_len: p.lk_out_evals.len(), diff --git a/ceno_recursion/src/zkvm_verifier/verifier.rs b/ceno_recursion/src/zkvm_verifier/verifier.rs index 14901139e..665f7556a 100644 --- a/ceno_recursion/src/zkvm_verifier/verifier.rs +++ b/ceno_recursion/src/zkvm_verifier/verifier.rs @@ -7,8 +7,8 @@ use crate::{ PolyEvaluator, UniPolyExtrapolator, arr_product, assert_ext_arr_eq, build_eq_x_r_vec_sequential, challenger_hint_observe, challenger_multi_observe, concat, dot_product as ext_dot_product, eq_eval, eq_eval_less_or_equal_than, - eval_ceno_expr_with_instance, eval_wellform_address_vec, exts_to_felts, gen_alpha_pows, - mask_arr, reverse, + eval_ceno_expr_with_instance, eval_stacked_constant, eval_stacked_wellform_address_vec, + eval_wellform_address_vec, exts_to_felts, gen_alpha_pows, mask_arr, reverse, }, basefold_verifier::{ basefold::{BasefoldCommitmentVariable, RoundOpeningVariable, RoundVariable}, @@ -18,10 +18,12 @@ use crate::{ verifier::batch_verify, }, tower_verifier::{ - binding::{PointAndEvalVariable, PointVariable}, + binding::{IOPProverMessageVecVariable, PointAndEvalVariable, PointVariable}, program::{iop_verifier_state_verify, verify_tower_proof}, }, - transcript::transcript_observe_label, + transcript::{ + transcript_label_as_array, transcript_observe_label, transcript_observe_label_felts, + }, zkvm_verifier::binding::{ EccQuarkProofVariable, GKRProofVariable, LayerProofVariable, SelectorContextVariable, SepticExtensionVariable, SepticPointVariable, SumcheckLayerProofVariable, @@ -56,6 +58,17 @@ type Pcs = Basefold; const NUM_FANIN: usize = 2; const SEPTIC_EXTENSION_DEGREE: usize = 7; +#[derive(DslVariable, Clone)] +pub struct PendingChipMainVariable { + pub num_var_with_rotation: Usize, + pub out_evals: Array>, + pub pi_evals: Array>, + pub selector_ctxs: Array>, + pub shard_ec_sum: SepticPointVariable, + pub prod_out_evals: Array>, + pub logup_out_evals: Array>, +} + pub fn transcript_group_observe_label( builder: &mut Builder, challenger_group: &mut Vec>, @@ -249,6 +262,21 @@ pub fn verify_zkvm_proof>( // collect fork sampling result let forked_samples: Array> = builder.dyn_array(proofs_len.get_var()); let forked_sample_index: Usize = builder.eval(C::N::ZERO); + let max_out_evals = max_first_layer_out_evals(vk); + let max_selector_ctxs = max_first_layer_selector_ctxs(vk); + let max_pi_evals = max_pi_evals(vk); + let pending_num_vars: Array> = builder.dyn_array(proofs_len.clone()); + let pending_out_evals_len: Usize = + builder.eval(proofs_len.clone() * Usize::from(max_out_evals)); + let pending_out_evals: Array> = + builder.dyn_array(pending_out_evals_len); + let pending_selector_ctxs_len: Usize = + builder.eval(proofs_len.clone() * Usize::from(max_selector_ctxs)); + let pending_selector_ctxs: Array> = + builder.dyn_array(pending_selector_ctxs_len); + let pending_pi_evals_len: Usize = + builder.eval(proofs_len.clone() * Usize::from(max_pi_evals)); + let pending_pi_evals: Array> = builder.dyn_array(pending_pi_evals_len); for (i, (circuit_name, chip_vk)) in vk.circuit_vks.iter().enumerate() { let circuit_vk = &vk.circuit_vks[circuit_name]; @@ -306,12 +334,7 @@ pub fn verify_zkvm_proof>( ); builder.cycle_tracker_start("Verify chip proof"); - let ( - input_opening_point, - chip_shard_ec_sum, - chip_prod_out_evals, - chip_logup_out_evals, - ) = verify_chip_proof( + let pending_chip_main = verify_chip_proof_pre_main( circuit_name, builder, &mut chip_challenger, @@ -324,6 +347,42 @@ pub fn verify_zkvm_proof>( ); builder.cycle_tracker_end("Verify chip proof"); + builder.set( + &pending_num_vars, + forked_sample_index.get_var(), + pending_chip_main.num_var_with_rotation.clone(), + ); + let out_offset: Usize = + builder.eval(forked_sample_index.clone() * Usize::from(max_out_evals)); + builder + .range(0, pending_chip_main.out_evals.len()) + .for_each(|idx_vec, builder| { + let idx = idx_vec[0]; + let dst: Usize = builder.eval(out_offset.clone() + idx); + let value = builder.get(&pending_chip_main.out_evals, idx); + builder.set(&pending_out_evals, dst, value); + }); + let selector_offset: Usize = + builder.eval(forked_sample_index.clone() * Usize::from(max_selector_ctxs)); + builder + .range(0, pending_chip_main.selector_ctxs.len()) + .for_each(|idx_vec, builder| { + let idx = idx_vec[0]; + let dst: Usize = builder.eval(selector_offset.clone() + idx); + let value = builder.get(&pending_chip_main.selector_ctxs, idx); + builder.set(&pending_selector_ctxs, dst, value); + }); + let pi_offset: Usize = + builder.eval(forked_sample_index.clone() * Usize::from(max_pi_evals)); + builder + .range(0, pending_chip_main.pi_evals.len()) + .for_each(|idx_vec, builder| { + let idx = idx_vec[0]; + let dst: Usize = builder.eval(pi_offset.clone() + idx); + let value = builder.get(&pending_chip_main.pi_evals, idx); + builder.set(&pending_pi_evals, dst, value); + }); + let chip_logup_sum: Ext = builder.constant(C::EF::ZERO); builder .range(0, chip_proof.lk_out_evals_len.clone()) @@ -333,7 +392,7 @@ pub fn verify_zkvm_proof>( let end: Usize = builder.eval(start.clone() + C::N::from_canonical_usize(4)); - let evals = chip_logup_out_evals.slice(builder, start, end); + let evals = pending_chip_main.logup_out_evals.slice(builder, start, end); let p1 = builder.get(&evals, 0); let p2 = builder.get(&evals, 1); let q1 = builder.get(&evals, 2); @@ -345,60 +404,32 @@ pub fn verify_zkvm_proof>( builder.assign(&logup_sum, logup_sum + chip_logup_sum); - let point_clone: Array> = - builder.eval(input_opening_point.clone()); - let (wits_in_evals, fixed_in_evals) = split_input_opening_evals( - builder, - &chip_proof, - circuit_vk.get_cs().num_witin(), - circuit_vk.get_cs().num_fixed(), - ); - - if circuit_vk.get_cs().num_witin() > 0 { - let witin_round: RoundOpeningVariable = builder.eval(RoundOpeningVariable { - num_var: input_opening_point.len().get_var(), - point_and_evals: PointAndEvalsVariable { - point: PointVariable { fs: point_clone }, - evals: wits_in_evals, - }, - }); - builder.set_value(&witin_openings, num_witin_openings.get_var(), witin_round); - builder.inc(&num_witin_openings); - } - if circuit_vk.get_cs().num_fixed() > 0 { - let fixed_round: RoundOpeningVariable = builder.eval(RoundOpeningVariable { - num_var: input_opening_point.len().get_var(), - point_and_evals: PointAndEvalsVariable { - point: PointVariable { - fs: input_opening_point, - }, - evals: fixed_in_evals, - }, - }); - - builder.set_value(&fixed_openings, num_fixed_openings.get_var(), fixed_round); - builder.inc(&num_fixed_openings); - } - let r_out_evals_end: Usize = builder.eval(chip_proof.r_out_evals_len * Usize::from(2)); builder .range(0, r_out_evals_end.clone()) .for_each(|idx_vec, builder| { - let e = builder.get(&chip_prod_out_evals, idx_vec[0]); + let e = builder.get(&pending_chip_main.prod_out_evals, idx_vec[0]); builder.assign(&prod_r, prod_r * e); }); builder - .range(r_out_evals_end, chip_prod_out_evals.len()) + .range(r_out_evals_end, pending_chip_main.prod_out_evals.len()) .for_each(|idx_vec, builder| { - let e = builder.get(&chip_prod_out_evals, idx_vec[0]); + let e = builder.get(&pending_chip_main.prod_out_evals, idx_vec[0]); builder.assign(&prod_w, prod_w * e); }); builder - .if_ne(chip_shard_ec_sum.is_infinity.clone(), Usize::from(1)) + .if_ne( + pending_chip_main.shard_ec_sum.is_infinity.clone(), + Usize::from(1), + ) .then(|builder| { - add_septic_points_in_place(builder, &shard_ec_sum, &chip_shard_ec_sum); + add_septic_points_in_place( + builder, + &shard_ec_sum, + &pending_chip_main.shard_ec_sum, + ); }); let chip_sample = chip_challenger.sample_ext(builder); @@ -409,10 +440,6 @@ pub fn verify_zkvm_proof>( }); } - // truncate the witin and fixed opening arrays - witin_openings.truncate(builder, num_witin_openings); - fixed_openings.truncate(builder, num_fixed_openings); - // all proofs must be verified without missing builder.assert_eq::>(num_chips_verified, chip_indices.len()); @@ -430,6 +457,32 @@ pub fn verify_zkvm_proof>( challenger.observe_slice(builder, sample_felts); }); + verify_batched_main_constraints( + builder, + &mut challenger, + &zkvm_proof_input, + vk, + &challenges, + proofs_len.clone(), + &chip_indices, + &pending_num_vars, + &pending_out_evals, + &pending_selector_ctxs, + &pending_pi_evals, + max_out_evals, + max_selector_ctxs, + max_pi_evals, + &witin_openings, + &fixed_openings, + &num_witin_openings, + &num_fixed_openings, + &unipoly_extrapolator, + ); + + // truncate the witin and fixed opening arrays + witin_openings.truncate(builder, num_witin_openings); + fixed_openings.truncate(builder, num_fixed_openings); + let rounds: Array> = if num_fixed_opening > 0 { builder.dyn_array(2) } else { @@ -512,30 +565,7 @@ pub fn verify_zkvm_proof>( shard_ec_sum } -fn split_input_opening_evals( - builder: &mut Builder, - chip_proof: &ZKVMChipProofInputVariable, - num_witin: usize, - num_fixed: usize, -) -> (Array>, Array>) { - let last_layer_idx: Usize = - builder.eval(chip_proof.gkr_iop_proof.layer_proofs.len() - Usize::from(1)); - let last_layer = builder.get(&chip_proof.gkr_iop_proof.layer_proofs, last_layer_idx); - let main_evals = last_layer.main.evals; - - let wit_end = Usize::from(num_witin); - let fixed_end: Usize = builder.eval(wit_end.clone() + Usize::from(num_fixed)); - // Native verifier accepts extra trailing evals; only the prefix is consumed here. - // Keep recursion semantics aligned by slicing the required prefix. - let eval_prefix = main_evals.slice(builder, Usize::from(0), fixed_end.clone()); - - ( - eval_prefix.slice(builder, Usize::from(0), wit_end), - eval_prefix.slice(builder, Usize::from(num_witin), fixed_end), - ) -} - -pub fn verify_chip_proof( +pub fn verify_chip_proof_pre_main( circuit_name: &str, builder: &mut Builder, challenger: &mut DuplexChallengerVariable, @@ -544,13 +574,8 @@ pub fn verify_chip_proof( challenges: &Array>, vk: &VerifyingKey, unipoly_extrapolator: &UniPolyExtrapolator, - poly_evaluator: &mut PolyEvaluator, -) -> ( - Array>, - SepticPointVariable, - Array>, - Array>, -) { + _poly_evaluator: &mut PolyEvaluator, +) -> PendingChipMainVariable { let composed_cs = vk.get_cs(); let ComposedConstrainSystem { zkvm_v1_css: cs, @@ -703,11 +728,20 @@ pub fn verify_chip_proof( .clone(), ); - let mut selector_ctxs = Vec::with_capacity(first_layer.out_sel_and_eval_exprs.len()); - for (selector, _) in &first_layer.out_sel_and_eval_exprs { - let ctx = if cs.ec_final_sum.is_empty() { - let non_shard_n1 = Usize::Var(builder.get(&chip_proof.num_instances, 1)); - builder.assert_usize_eq(non_shard_n1, Usize::from(0)); + let group_has_tower = first_layer_output_group_has_tower(composed_cs, &gkr_circuit); + let selector_ctxs: Array> = + builder.dyn_array(first_layer.out_sel_and_eval_exprs.len()); + for (group_idx, ((selector, _), has_tower)) in first_layer + .out_sel_and_eval_exprs + .iter() + .zip(group_has_tower.iter()) + .enumerate() + { + let ctx = if !has_tower || cs.ec_final_sum.is_empty() { + if cs.ec_final_sum.is_empty() { + let non_shard_n1 = Usize::Var(builder.get(&chip_proof.num_instances, 1)); + builder.assert_usize_eq(non_shard_n1, Usize::from(0)); + } SelectorContextVariable { offset: Usize::from(0), offset_bit_decomps: zero_bit_decomps.clone(), @@ -766,7 +800,7 @@ pub fn verify_chip_proof( num_vars: num_var_with_rotation.clone(), } }; - selector_ctxs.push(ctx); + builder.set(&selector_ctxs, group_idx, ctx); } if !first_layer.rotation_exprs.1.is_empty() { @@ -861,170 +895,781 @@ pub fn verify_chip_proof( } } - if composed_cs.has_ecc_ops() { - let [ - x_group_idx, - y_group_idx, - slope_group_idx, - x3_group_idx, - y3_group_idx, - ] = first_layer - .ecc_bridge_group_indices() - .expect("ecc bridge selectors missing"); - - transcript_observe_label(builder, challenger, b"ecc_gkr_bridge_r"); - let sample_r: Ext = challenger.sample_ext(builder); - let one_minus_r: Ext = builder.eval(one - sample_r); - let ecc_proof = &chip_proof.ecc_proof; + PendingChipMainVariable { + num_var_with_rotation, + out_evals, + pi_evals: circuit_pi_evals, + selector_ctxs, + shard_ec_sum, + prod_out_evals, + logup_out_evals, + } +} - let xy_point_len: Usize = builder.eval(ecc_proof.rt.fs.len() + Usize::from(1)); - let xy_point: Array> = builder.dyn_array(xy_point_len); - builder.set(&xy_point, 0, sample_r); - builder - .range(0, ecc_proof.rt.fs.len()) - .for_each(|idx_vec, builder| { - let idx = idx_vec[0]; - let v = builder.get(&ecc_proof.rt.fs, idx); - let shifted_idx = Usize::Var(Var::uninit(builder)); - builder.assign(&shifted_idx, idx + Usize::from(1)); - builder.set(&xy_point, shifted_idx, v); - }); +fn tower_output_count(composed_cs: &ComposedConstrainSystem) -> usize { + let cs = &composed_cs.zkvm_v1_css; + let num_reads = cs.r_expressions.len() + cs.r_table_expressions.len(); + let num_writes = cs.w_expressions.len() + cs.w_table_expressions.len(); + let num_lk_num = cs.lk_table_expressions.len(); + let num_lk_den = if !cs.lk_table_expressions.is_empty() { + cs.lk_table_expressions.len() + } else { + cs.lk_expressions.len() + }; + num_reads + num_writes + num_lk_num + num_lk_den +} - let s_point_len: Usize = builder.eval(ecc_proof.rt.fs.len() + Usize::from(1)); - let s_point: Array> = builder.dyn_array(s_point_len.clone()); - builder - .range(0, ecc_proof.rt.fs.len()) - .for_each(|idx_vec, builder| { - let idx = idx_vec[0]; - let v = builder.get(&ecc_proof.rt.fs, idx); - builder.set(&s_point, idx, v); - }); - builder.set(&s_point, ecc_proof.rt.fs.len(), sample_r); +fn first_layer_output_group_has_tower( + composed_cs: &ComposedConstrainSystem, + circuit: &GKRCircuit, +) -> Vec { + let first_layer = circuit.layers.first().expect("empty gkr circuit layer"); + let tower_outputs = tower_output_count(composed_cs); + let mut remaining = tower_outputs; + first_layer + .out_sel_and_eval_exprs + .iter() + .map(|(_, outputs)| { + if remaining == 0 { + false + } else { + remaining = remaining.saturating_sub(outputs.len()); + true + } + }) + .collect() +} - let x3y3_point: Array> = builder.dyn_array(s_point_len.clone()); - builder - .range(0, ecc_proof.rt.fs.len()) - .for_each(|idx_vec, builder| { - let idx = idx_vec[0]; - let v = builder.get(&ecc_proof.rt.fs, idx); - builder.set(&x3y3_point, idx, v); - }); - builder.set(&x3y3_point, ecc_proof.rt.fs.len(), one); +fn max_first_layer_out_evals(vk: &ZKVMVerifyingKey) -> usize { + vk.circuit_vks + .values() + .filter_map(|chip_vk| chip_vk.get_cs().gkr_circuit.as_ref()) + .map(|circuit| circuit.n_evaluations) + .max() + .unwrap_or(0) +} - let degree = SEPTIC_EXTENSION_DEGREE; - for (idx, eval_expr) in first_layer.out_sel_and_eval_exprs[x_group_idx] - .1 - .iter() - .enumerate() - { - let EvalExpression::Single(out_idx) = eval_expr else { - panic!("ecc bridge x group must use EvalExpression::Single"); - }; - let x0 = builder.get(&ecc_proof.evals, 3 + degree + idx); - let x1 = builder.get(&ecc_proof.evals, 3 + degree * 3 + idx); - let eval = builder.eval(x0 * one_minus_r + x1 * sample_r); - let claim: PointAndEvalVariable = builder.eval(PointAndEvalVariable { - point: PointVariable { - fs: xy_point.clone(), - }, - eval, - }); - builder.set(&out_evals, *out_idx, claim); - } +fn max_first_layer_selector_ctxs(vk: &ZKVMVerifyingKey) -> usize { + vk.circuit_vks + .values() + .filter_map(|chip_vk| chip_vk.get_cs().gkr_circuit.as_ref()) + .filter_map(|circuit| circuit.layers.first()) + .map(|layer| layer.out_sel_and_eval_exprs.len()) + .max() + .unwrap_or(0) +} - for (idx, eval_expr) in first_layer.out_sel_and_eval_exprs[y_group_idx] - .1 - .iter() - .enumerate() - { - let EvalExpression::Single(out_idx) = eval_expr else { - panic!("ecc bridge y group must use EvalExpression::Single"); - }; - let y0 = builder.get(&ecc_proof.evals, 3 + degree * 2 + idx); - let y1 = builder.get(&ecc_proof.evals, 3 + degree * 4 + idx); - let eval = builder.eval(y0 * one_minus_r + y1 * sample_r); - let claim: PointAndEvalVariable = builder.eval(PointAndEvalVariable { - point: PointVariable { - fs: xy_point.clone(), - }, - eval, - }); - builder.set(&out_evals, *out_idx, claim); - } +fn max_pi_evals(vk: &ZKVMVerifyingKey) -> usize { + vk.circuit_vks + .values() + .map(|chip_vk| chip_vk.get_cs().zkvm_v1_css.instance.len()) + .max() + .unwrap_or(0) +} - for (idx, eval_expr) in first_layer.out_sel_and_eval_exprs[slope_group_idx] - .1 - .iter() - .enumerate() - { - let EvalExpression::Single(out_idx) = eval_expr else { - panic!("ecc bridge slope group must use EvalExpression::Single"); - }; - let s1 = builder.get(&ecc_proof.evals, 3 + idx); - let eval = builder.eval(s1 * sample_r); - let claim: PointAndEvalVariable = builder.eval(PointAndEvalVariable { - point: PointVariable { - fs: s_point.clone(), - }, - eval, +fn assign_ecc_bridge_claims( + builder: &mut Builder, + first_layer: &Layer, + out_evals: &Array>, + ecc_proof: &EccQuarkProofVariable, + sample_r: Ext, +) { + let [ + x_group_idx, + y_group_idx, + slope_group_idx, + x3_group_idx, + y3_group_idx, + ] = first_layer + .ecc_bridge_group_indices() + .expect("ecc bridge selectors missing"); + let one: Ext = builder.constant(C::EF::ONE); + let one_minus_r: Ext = builder.eval(one - sample_r); + + let xy_point_len: Usize = builder.eval(ecc_proof.rt.fs.len() + Usize::from(1)); + let xy_point: Array> = builder.dyn_array(xy_point_len); + builder.set(&xy_point, 0, sample_r); + builder + .range(0, ecc_proof.rt.fs.len()) + .for_each(|idx_vec, builder| { + let idx = idx_vec[0]; + let v = builder.get(&ecc_proof.rt.fs, idx); + let shifted_idx = Usize::Var(Var::uninit(builder)); + builder.assign(&shifted_idx, idx + Usize::from(1)); + builder.set(&xy_point, shifted_idx, v); + }); + + let s_point_len: Usize = builder.eval(ecc_proof.rt.fs.len() + Usize::from(1)); + let s_point: Array> = builder.dyn_array(s_point_len.clone()); + builder + .range(0, ecc_proof.rt.fs.len()) + .for_each(|idx_vec, builder| { + let idx = idx_vec[0]; + let v = builder.get(&ecc_proof.rt.fs, idx); + builder.set(&s_point, idx, v); + }); + builder.set(&s_point, ecc_proof.rt.fs.len(), sample_r); + + let x3y3_point: Array> = builder.dyn_array(s_point_len); + builder + .range(0, ecc_proof.rt.fs.len()) + .for_each(|idx_vec, builder| { + let idx = idx_vec[0]; + let v = builder.get(&ecc_proof.rt.fs, idx); + builder.set(&x3y3_point, idx, v); + }); + builder.set(&x3y3_point, ecc_proof.rt.fs.len(), one); + + let degree = SEPTIC_EXTENSION_DEGREE; + for (idx, eval_expr) in first_layer.out_sel_and_eval_exprs[x_group_idx] + .1 + .iter() + .enumerate() + { + let EvalExpression::Single(out_idx) = eval_expr else { + panic!("ecc bridge x group must use EvalExpression::Single"); + }; + let x0 = builder.get(&ecc_proof.evals, 3 + degree + idx); + let x1 = builder.get(&ecc_proof.evals, 3 + degree * 3 + idx); + let eval = builder.eval(x0 * one_minus_r + x1 * sample_r); + let claim: PointAndEvalVariable = builder.eval(PointAndEvalVariable { + point: PointVariable { + fs: xy_point.clone(), + }, + eval, + }); + builder.set(out_evals, *out_idx, claim); + } + + for (idx, eval_expr) in first_layer.out_sel_and_eval_exprs[y_group_idx] + .1 + .iter() + .enumerate() + { + let EvalExpression::Single(out_idx) = eval_expr else { + panic!("ecc bridge y group must use EvalExpression::Single"); + }; + let y0 = builder.get(&ecc_proof.evals, 3 + degree * 2 + idx); + let y1 = builder.get(&ecc_proof.evals, 3 + degree * 4 + idx); + let eval = builder.eval(y0 * one_minus_r + y1 * sample_r); + let claim: PointAndEvalVariable = builder.eval(PointAndEvalVariable { + point: PointVariable { + fs: xy_point.clone(), + }, + eval, + }); + builder.set(out_evals, *out_idx, claim); + } + + for (idx, eval_expr) in first_layer.out_sel_and_eval_exprs[slope_group_idx] + .1 + .iter() + .enumerate() + { + let EvalExpression::Single(out_idx) = eval_expr else { + panic!("ecc bridge slope group must use EvalExpression::Single"); + }; + let s1 = builder.get(&ecc_proof.evals, 3 + idx); + let eval = builder.eval(s1 * sample_r); + let claim: PointAndEvalVariable = builder.eval(PointAndEvalVariable { + point: PointVariable { + fs: s_point.clone(), + }, + eval, + }); + builder.set(out_evals, *out_idx, claim); + } + + for (idx, eval_expr) in first_layer.out_sel_and_eval_exprs[x3_group_idx] + .1 + .iter() + .enumerate() + { + let EvalExpression::Single(out_idx) = eval_expr else { + panic!("ecc bridge x3 group must use EvalExpression::Single"); + }; + let eval = builder.get(&ecc_proof.evals, 3 + degree * 5 + idx); + let claim: PointAndEvalVariable = builder.eval(PointAndEvalVariable { + point: PointVariable { + fs: x3y3_point.clone(), + }, + eval, + }); + builder.set(out_evals, *out_idx, claim); + } + + for (idx, eval_expr) in first_layer.out_sel_and_eval_exprs[y3_group_idx] + .1 + .iter() + .enumerate() + { + let EvalExpression::Single(out_idx) = eval_expr else { + panic!("ecc bridge y3 group must use EvalExpression::Single"); + }; + let eval = builder.get(&ecc_proof.evals, 3 + degree * 6 + idx); + let claim: PointAndEvalVariable = builder.eval(PointAndEvalVariable { + point: PointVariable { + fs: x3y3_point.clone(), + }, + eval, + }); + builder.set(out_evals, *out_idx, claim); + } +} + +fn iop_verifier_state_verify_with_proof_degree( + builder: &mut Builder, + challenger: &mut DuplexChallengerVariable, + out_claim: &Ext, + prover_messages: &IOPProverMessageVecVariable, + max_num_variables: Felt, + unipoly_extrapolator: &UniPolyExtrapolator, +) -> ( + Array::F, ::EF>>, + Ext<::F, ::EF>, +) { + let max_num_variables_usize: Usize = + Usize::from(builder.cast_felt_to_var(max_num_variables)); + let max_degree_felt = builder.unsafe_cast_var_to_felt(prover_messages.prover_message_size); + let pre_verified_integrity_data: Array> = builder.dyn_array(4); + builder.set_value(&pre_verified_integrity_data, 0, max_num_variables); + builder.set_value(&pre_verified_integrity_data, 2, max_degree_felt); + challenger_multi_observe(builder, challenger, &pre_verified_integrity_data); + + builder.assert_var_eq(max_num_variables_usize.get_var(), prover_messages.len()); + + let challenges: Array> = builder.dyn_array(max_num_variables_usize.clone()); + let expected: Ext = builder.eval(*out_claim); + let internal_round_label = transcript_label_as_array(builder, b"Internal round"); + + let curr_offset: Var = builder.eval(C::N::ZERO); + builder + .range(0, max_num_variables_usize.clone()) + .for_each(|i_vec, builder| { + let i = i_vec[0]; + let next_offset: Var = + builder.eval(curr_offset + prover_messages.prover_message_size); + let prover_msg = prover_messages + .evaluations + .slice(builder, curr_offset, next_offset); + builder.assign(&curr_offset, next_offset); + + unsafe { + let prover_msg_felts = exts_to_felts(builder, &prover_msg); + challenger_multi_observe(builder, challenger, &prover_msg_felts); + } + + transcript_observe_label_felts(builder, challenger, &internal_round_label); + let challenge = challenger.sample_ext(builder); + + let e1 = builder.get(&prover_msg, 0); + let e0 = builder.eval(expected - e1); + let p_r: Ext = builder.constant(C::EF::ZERO); + builder + .if_eq( + Usize::Var(prover_messages.prover_message_size), + Usize::from(1), + ) + .then(|builder| { + let value = + unipoly_extrapolator.extrapolate_uni_poly_deg_1(builder, e0, e1, challenge); + builder.assign(&p_r, value); + }); + builder + .if_eq( + Usize::Var(prover_messages.prover_message_size), + Usize::from(2), + ) + .then(|builder| { + let p1 = e1; + let p2 = builder.get(&prover_msg, 1); + let value = unipoly_extrapolator + .extrapolate_uni_poly_deg_2(builder, e0, p1, p2, challenge); + builder.assign(&p_r, value); + }); + builder + .if_eq( + Usize::Var(prover_messages.prover_message_size), + Usize::from(3), + ) + .then(|builder| { + let p1 = e1; + let p2 = builder.get(&prover_msg, 1); + let p3 = builder.get(&prover_msg, 2); + let value = unipoly_extrapolator + .extrapolate_uni_poly_deg_3(builder, e0, p1, p2, p3, challenge); + builder.assign(&p_r, value); + }); + builder + .if_eq( + Usize::Var(prover_messages.prover_message_size), + Usize::from(4), + ) + .then(|builder| { + let p1 = e1; + let p2 = builder.get(&prover_msg, 1); + let p3 = builder.get(&prover_msg, 2); + let p4 = builder.get(&prover_msg, 3); + let value = unipoly_extrapolator + .extrapolate_uni_poly_deg_4(builder, e0, p1, p2, p3, p4, challenge); + builder.assign(&p_r, value); + }); + + builder.assign(&expected, p_r); + builder.set_value(&challenges, i, challenge); + }); + + (challenges, expected) +} + +#[allow(clippy::too_many_arguments)] +fn verify_batched_main_constraints>( + builder: &mut Builder, + challenger: &mut DuplexChallengerVariable, + zkvm_proof_input: &ZKVMProofInputVariable, + vk: &ZKVMVerifyingKey, + challenges: &Array>, + proofs_len: Usize, + chip_indices: &Array>, + pending_num_vars: &Array>, + pending_out_evals: &Array>, + pending_selector_ctxs: &Array>, + pending_pi_evals: &Array>, + max_out_evals: usize, + max_selector_ctxs: usize, + max_pi_evals: usize, + witin_openings: &Array>, + fixed_openings: &Array>, + num_witin_openings: &Usize, + num_fixed_openings: &Usize, + unipoly_extrapolator: &UniPolyExtrapolator, +) { + let total_evals: Usize = builder.eval(C::N::ZERO); + let total_exprs: Usize = builder.eval(C::N::ZERO); + let num_chip_groups_counted: Usize = builder.eval(C::N::ZERO); + + for (i, (circuit_name, chip_vk)) in vk.circuit_vks.iter().enumerate() { + let chip_id: Var = builder.get(chip_indices, num_chip_groups_counted.get_var()); + builder.if_eq(chip_id, RVar::from(i)).then(|builder| { + let chip_proofs = builder.get( + &zkvm_proof_input.chip_proofs, + num_chip_groups_counted.get_var(), + ); + let gkr_circuit = chip_vk + .get_cs() + .gkr_circuit + .as_ref() + .unwrap_or_else(|| panic!("{circuit_name} missing gkr circuit")); + let layer = gkr_circuit + .layers + .first() + .unwrap_or_else(|| panic!("{circuit_name} empty gkr circuit")); + let eval_len = Usize::from(layer.n_witin + layer.n_fixed + layer.n_structural_witin); + let expr_len = Usize::from(layer.exprs.len()); + iter_zip!(builder, chip_proofs).for_each(|_, builder| { + builder.assign(&total_evals, total_evals.clone() + eval_len.clone()); + builder.assign(&total_exprs, total_exprs.clone() + expr_len.clone()); }); - builder.set(&out_evals, *out_idx, claim); - } + builder.inc(&num_chip_groups_counted); + }); + } + builder.assert_eq::>(num_chip_groups_counted, chip_indices.len()); + builder.assert_usize_eq( + zkvm_proof_input.main_constraint_proof.evals.len(), + total_evals, + ); - for (idx, eval_expr) in first_layer.out_sel_and_eval_exprs[x3_group_idx] - .1 - .iter() - .enumerate() - { - let EvalExpression::Single(out_idx) = eval_expr else { - panic!("ecc bridge x3 group must use EvalExpression::Single"); - }; - let eval = builder.get(&ecc_proof.evals, 3 + degree * 5 + idx); - let claim: PointAndEvalVariable = builder.eval(PointAndEvalVariable { - point: PointVariable { - fs: x3y3_point.clone(), - }, - eval, + let ecc_bridge_samples: Array> = builder.dyn_array(proofs_len.clone()); + let proof_idx: Usize = builder.eval(C::N::ZERO); + let num_chip_groups_bridge_sampled: Usize = builder.eval(C::N::ZERO); + + for (i, (circuit_name, _chip_vk)) in vk.circuit_vks.iter().enumerate() { + let circuit_vk = &vk.circuit_vks[circuit_name]; + let chip_id: Var = + builder.get(chip_indices, num_chip_groups_bridge_sampled.get_var()); + + builder.if_eq(chip_id, RVar::from(i)).then(|builder| { + let chip_proofs = builder.get( + &zkvm_proof_input.chip_proofs, + num_chip_groups_bridge_sampled.get_var(), + ); + iter_zip!(builder, chip_proofs).for_each(|ptr_vec, builder| { + let chip_proof = builder.iter_ptr_get(&chip_proofs, ptr_vec[0]); + let zero: Ext = builder.constant(C::EF::ZERO); + builder.set(&ecc_bridge_samples, proof_idx.get_var(), zero); + if circuit_vk.get_cs().has_ecc_ops() { + transcript_observe_label(builder, challenger, b"ecc_gkr_bridge_r"); + let sample_r = challenger.sample_ext(builder); + builder.set(&ecc_bridge_samples, proof_idx.get_var(), sample_r); + builder.assert_nonzero(&chip_proof.has_ecc_proof); + } + builder.inc(&proof_idx); }); - builder.set(&out_evals, *out_idx, claim); - } + builder.inc(&num_chip_groups_bridge_sampled); + }); + } + builder.assert_eq::>(num_chip_groups_bridge_sampled, chip_indices.len()); - for (idx, eval_expr) in first_layer.out_sel_and_eval_exprs[y3_group_idx] - .1 - .iter() - .enumerate() - { - let EvalExpression::Single(out_idx) = eval_expr else { - panic!("ecc bridge y3 group must use EvalExpression::Single"); - }; - let eval = builder.get(&ecc_proof.evals, 3 + degree * 6 + idx); - let claim: PointAndEvalVariable = builder.eval(PointAndEvalVariable { - point: PointVariable { - fs: x3y3_point.clone(), - }, - eval, + transcript_observe_label(builder, challenger, b"combine subset evals"); + let alpha_pows = gen_alpha_pows(builder, challenger, total_exprs.clone()); + + let sigma: Ext = builder.constant(C::EF::ZERO); + let alpha_idx: Usize = builder.eval(C::N::ZERO); + let proof_idx: Usize = builder.eval(C::N::ZERO); + let num_chip_groups_verified: Usize = builder.eval(C::N::ZERO); + + for (i, (circuit_name, chip_vk)) in vk.circuit_vks.iter().enumerate() { + let circuit_vk = &vk.circuit_vks[circuit_name]; + let chip_id: Var = builder.get(chip_indices, num_chip_groups_verified.get_var()); + + builder.if_eq(chip_id, RVar::from(i)).then(|builder| { + let chip_proofs = builder.get( + &zkvm_proof_input.chip_proofs, + num_chip_groups_verified.get_var(), + ); + let gkr_circuit = chip_vk + .get_cs() + .gkr_circuit + .as_ref() + .unwrap_or_else(|| panic!("{circuit_name} missing gkr circuit")); + let layer = gkr_circuit + .layers + .first() + .unwrap_or_else(|| panic!("{circuit_name} empty gkr circuit")); + + iter_zip!(builder, chip_proofs).for_each(|ptr_vec, builder| { + let chip_proof = builder.iter_ptr_get(&chip_proofs, ptr_vec[0]); + let out_offset: Usize = + builder.eval(proof_idx.clone() * Usize::from(max_out_evals)); + let out_end: Usize = + builder.eval(out_offset.clone() + Usize::from(max_out_evals)); + let out_evals = pending_out_evals.slice(builder, out_offset, out_end); + + if circuit_vk.get_cs().has_ecc_ops() { + let sample_r = builder.get(&ecc_bridge_samples, proof_idx.get_var()); + assign_ecc_bridge_claims( + builder, + layer, + &out_evals, + &chip_proof.ecc_proof, + sample_r, + ); + } + + let eval_and_dedup_points = + extract_claim_and_point(builder, layer, &out_evals, challenges); + builder.assert_usize_eq( + Usize::from(layer.out_sel_and_eval_exprs.len()), + eval_and_dedup_points.len(), + ); + + builder + .range(0, eval_and_dedup_points.len()) + .for_each(|idx_vec, builder| { + let ClaimAndPoint { evals, .. } = + builder.get(&eval_and_dedup_points, idx_vec[0]); + let end_idx: Usize = builder.eval(alpha_idx.clone() + evals.len()); + let alpha_slice = + alpha_pows.slice(builder, alpha_idx.clone(), end_idx.clone()); + let sub_sum = ext_dot_product(builder, &evals, &alpha_slice); + builder.assign(&sigma, sigma + sub_sum); + builder.assign(&alpha_idx, end_idx); + }); + builder.inc(&proof_idx); }); - builder.set(&out_evals, *out_idx, claim); - } + builder.inc(&num_chip_groups_verified); + }); } + builder.assert_eq::>(num_chip_groups_verified, chip_indices.len()); - builder.cycle_tracker_start("Verify GKR Circuit"); - let rt = verify_gkr_circuit( + let max_num_variables = + builder.unsafe_cast_var_to_felt(zkvm_proof_input.main_constraint_proof.proof.len()); + let (global_in_point, expected_evaluation) = iop_verifier_state_verify_with_proof_degree( builder, challenger, - num_var_with_rotation, - gkr_circuit, - &chip_proof.gkr_iop_proof, - challenges, - &circuit_pi_evals, - &out_evals, - selector_ctxs, + &sigma, + &zkvm_proof_input.main_constraint_proof.proof, + max_num_variables, unipoly_extrapolator, - poly_evaluator, ); - builder.cycle_tracker_end("Verify GKR Circuit"); + challenger_observe_exts( + builder, + challenger, + &zkvm_proof_input.main_constraint_proof.evals, + ); + + let got_claim: Ext = builder.constant(C::EF::ZERO); + let eval_idx: Usize = builder.eval(C::N::ZERO); + let alpha_idx: Usize = builder.eval(C::N::ZERO); + let proof_idx: Usize = builder.eval(C::N::ZERO); + let num_chip_groups_claimed: Usize = builder.eval(C::N::ZERO); + + for (i, (circuit_name, chip_vk)) in vk.circuit_vks.iter().enumerate() { + let circuit_vk = &vk.circuit_vks[circuit_name]; + let chip_id: Var = builder.get(chip_indices, num_chip_groups_claimed.get_var()); + + builder.if_eq(chip_id, RVar::from(i)).then(|builder| { + let chip_proofs = builder.get( + &zkvm_proof_input.chip_proofs, + num_chip_groups_claimed.get_var(), + ); + let gkr_circuit = chip_vk + .get_cs() + .gkr_circuit + .as_ref() + .unwrap_or_else(|| panic!("{circuit_name} missing gkr circuit")); + let layer = gkr_circuit + .layers + .first() + .unwrap_or_else(|| panic!("{circuit_name} empty gkr circuit")); + let eval_len = Usize::from(layer.n_witin + layer.n_fixed + layer.n_structural_witin); + let expr_len = Usize::from(layer.exprs.len()); + + iter_zip!(builder, chip_proofs).for_each(|ptr_vec, builder| { + let chip_proof = builder.iter_ptr_get(&chip_proofs, ptr_vec[0]); + let out_offset: Usize = + builder.eval(proof_idx.clone() * Usize::from(max_out_evals)); + let out_end: Usize = + builder.eval(out_offset.clone() + Usize::from(max_out_evals)); + let out_evals = pending_out_evals.slice(builder, out_offset, out_end); + let selector_offset: Usize = + builder.eval(proof_idx.clone() * Usize::from(max_selector_ctxs)); + let selector_end: Usize = + builder.eval(selector_offset.clone() + Usize::from(max_selector_ctxs)); + let selector_ctxs = + pending_selector_ctxs.slice(builder, selector_offset, selector_end); + let pi_offset: Usize = + builder.eval(proof_idx.clone() * Usize::from(max_pi_evals)); + let pi_end: Usize = + builder.eval(pi_offset.clone() + Usize::from(max_pi_evals)); + let pi_evals = pending_pi_evals.slice(builder, pi_offset, pi_end); + let num_var_with_rotation = builder.get(pending_num_vars, proof_idx.get_var()); + + if circuit_vk.get_cs().has_ecc_ops() { + let sample_r = builder.get(&ecc_bridge_samples, proof_idx.get_var()); + assign_ecc_bridge_claims( + builder, + layer, + &out_evals, + &chip_proof.ecc_proof, + sample_r, + ); + } + + let eval_and_dedup_points = + extract_claim_and_point(builder, layer, &out_evals, challenges); + let eval_end: Usize = builder.eval(eval_idx.clone() + eval_len.clone()); + let layer_evals = zkvm_proof_input.main_constraint_proof.evals.slice( + builder, + eval_idx.clone(), + eval_end.clone(), + ); + let in_point = + global_in_point.slice(builder, Usize::from(0), num_var_with_rotation.clone()); + + validate_batched_main_structural_evals( + builder, + layer, + &eval_and_dedup_points, + &selector_ctxs, + &pi_evals, + &layer_evals, + &in_point, + ); + + let alpha_end: Usize = builder.eval(alpha_idx.clone() + expr_len.clone()); + let alpha_slice = alpha_pows.slice(builder, alpha_idx.clone(), alpha_end.clone()); + let main_sumcheck_challenges_len: Usize = + builder.eval(alpha_slice.len() + Usize::from(2)); + let main_sumcheck_challenges: Array> = + builder.dyn_array(main_sumcheck_challenges_len.clone()); + let alpha = builder.get(challenges, 0); + let beta = builder.get(challenges, 1); + builder.set(&main_sumcheck_challenges, 0, alpha); + builder.set(&main_sumcheck_challenges, 1, beta); + let challenge_slice = + main_sumcheck_challenges.slice(builder, 2, main_sumcheck_challenges_len); + builder + .range(0, alpha_slice.len()) + .for_each(|idx_vec, builder| { + let alpha = builder.get(&alpha_slice, idx_vec[0]); + builder.set(&challenge_slice, idx_vec[0], alpha); + }); - (rt.fs, shard_ec_sum, prod_out_evals, logup_out_evals) + let term_claim = eval_batched_main_frontload_terms( + builder, + &layer_evals, + &pi_evals, + &main_sumcheck_challenges, + &global_in_point, + num_var_with_rotation, + layer + .main_sumcheck_expression + .as_ref() + .expect("missing main sumcheck expression"), + ); + builder.assign(&got_claim, got_claim + term_claim); + + if circuit_vk.get_cs().num_witin() > 0 { + let wit_end = Usize::from(layer.n_witin); + let wits_in_evals = layer_evals.slice(builder, Usize::from(0), wit_end); + let point_clone: Array> = builder.eval(in_point.clone()); + let witin_round: RoundOpeningVariable = builder.eval(RoundOpeningVariable { + num_var: in_point.len().get_var(), + point_and_evals: PointAndEvalsVariable { + point: PointVariable { fs: point_clone }, + evals: wits_in_evals, + }, + }); + builder.set_value(witin_openings, num_witin_openings.get_var(), witin_round); + builder.inc(num_witin_openings); + } + if circuit_vk.get_cs().num_fixed() > 0 { + let fixed_start = Usize::from(layer.n_witin); + let fixed_end: Usize = + builder.eval(fixed_start.clone() + Usize::from(layer.n_fixed)); + let fixed_in_evals = layer_evals.slice(builder, fixed_start, fixed_end); + let fixed_round: RoundOpeningVariable = builder.eval(RoundOpeningVariable { + num_var: in_point.len().get_var(), + point_and_evals: PointAndEvalsVariable { + point: PointVariable { fs: in_point }, + evals: fixed_in_evals, + }, + }); + builder.set_value(fixed_openings, num_fixed_openings.get_var(), fixed_round); + builder.inc(num_fixed_openings); + } + + builder.assign(&eval_idx, eval_end); + builder.assign(&alpha_idx, alpha_end); + builder.inc(&proof_idx); + }); + builder.inc(&num_chip_groups_claimed); + }); + } + builder.assert_eq::>(num_chip_groups_claimed, chip_indices.len()); + builder.assert_ext_eq(got_claim, expected_evaluation); +} + +fn validate_batched_main_structural_evals( + builder: &mut Builder, + layer: &Layer, + eval_and_dedup_points: &Array>, + selector_ctxs: &Array>, + pi: &Array>, + layer_evals: &Array>, + in_point: &Array>, +) { + let structural_witin_offset = layer.n_witin + layer.n_fixed; + + layer + .out_sel_and_eval_exprs + .iter() + .enumerate() + .for_each(|(idx, (sel_type, _))| { + let out_point = builder.get(eval_and_dedup_points, idx).point.fs; + let ctx = builder.get(selector_ctxs, idx); + let (wit_id, expected_eval) = + evaluate_selector(builder, sel_type, &out_point, in_point, &ctx); + let main_eval = builder.get(layer_evals, wit_id + structural_witin_offset); + builder.assert_ext_eq(main_eval, expected_eval); + }); + + let zero_const = builder.constant::>(C::EF::ZERO); + for s in &layer.structural_witins { + let id = s.id; + let witin_type = s.witin_type; + let wit_id = id as usize + structural_witin_offset; + let expected_eval: Ext = match witin_type { + StructuralWitInType::EqualDistanceSequence { + offset, + multi_factor, + descending, + .. + } => { + let offset = + builder.constant::>(C::EF::from_canonical_u32(offset)); + eval_wellform_address_vec( + builder, + offset, + multi_factor as u32, + in_point, + descending, + ) + } + StructuralWitInType::EqualDistanceDynamicSequence { + multi_factor, + descending, + offset_instance_id, + .. + } => { + let offset = builder.get(pi, offset_instance_id as usize); + eval_wellform_address_vec( + builder, + offset, + multi_factor as u32, + in_point, + descending, + ) + } + StructuralWitInType::StackedIncrementalSequence { .. } => { + eval_stacked_wellform_address_vec(builder, in_point) + } + StructuralWitInType::StackedConstantSequence { .. } => { + eval_stacked_constant(builder, in_point) + } + StructuralWitInType::InnerRepeatingIncrementalSequence { k, .. } => { + let r_slice = in_point.slice(builder, k, in_point.len()); + eval_wellform_address_vec(builder, zero_const, 1, &r_slice, false) + } + StructuralWitInType::OuterRepeatingIncrementalSequence { k, .. } => { + let r_slice = in_point.slice(builder, 0, k); + eval_wellform_address_vec(builder, zero_const, 1, &r_slice, false) + } + StructuralWitInType::Empty => continue, + }; + + let main_wit_eval = builder.get(layer_evals, wit_id); + builder.assert_ext_eq(expected_eval, main_wit_eval); + } +} + +fn eval_batched_main_frontload_terms( + builder: &mut Builder, + layer_evals: &Array>, + pi: &Array>, + challenges: &Array>, + global_in_point: &Array>, + num_var_with_rotation: Usize, + expression: &Expression, +) -> Ext { + let tail_factor: Ext = builder.constant(C::EF::ONE); + builder + .range(num_var_with_rotation, global_in_point.len()) + .for_each(|idx_vec, builder| { + let point = builder.get(global_in_point, idx_vec[0]); + builder.assign(&tail_factor, tail_factor * point); + }); + + let empty_arr: Array> = builder.dyn_array(0); + let weighted_layer_evals: Array> = builder.dyn_array(layer_evals.len()); + builder + .range(0, layer_evals.len()) + .for_each(|idx_vec, builder| { + let eval = builder.get(layer_evals, idx_vec[0]); + let weighted: Ext = builder.eval(eval * tail_factor); + builder.set(&weighted_layer_evals, idx_vec[0], weighted); + }); + + eval_ceno_expr_with_instance( + builder, + &empty_arr, + &weighted_layer_evals, + &empty_arr, + pi, + challenges, + expression, + ) } pub fn verify_gkr_circuit( diff --git a/ceno_zkvm/benches/riscv_add.rs b/ceno_zkvm/benches/riscv_add.rs index 8b4525679..e9e1a47d4 100644 --- a/ceno_zkvm/benches/riscv_add.rs +++ b/ceno_zkvm/benches/riscv_add.rs @@ -116,7 +116,7 @@ fn bench_add(c: &mut Criterion) { num_instances: [num_instances, 0], has_ecc_ops: false, }; - let task = ChipTask { + let mut task = ChipTask { task_id: 0, circuit_name: AddInstruction::::name(), circuit_idx: 0, @@ -131,7 +131,7 @@ fn bench_add(c: &mut Criterion) { structural_rmm: None, }; let _ = prover - .create_chip_proof(&task, &mut transcript) + .create_chip_proof(&mut task, &mut transcript) .expect("create_proof failed"); let elapsed = instant.elapsed(); println!( diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index c887bb920..12b5f29de 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -46,6 +46,7 @@ use serde::Serialize; use std::collections::{HashMap, HashSet}; use std::{ collections::{BTreeMap, BTreeSet}, + io::Write, marker::PhantomData, ops::Range, sync::Arc, @@ -55,24 +56,6 @@ use tracing::info_span; use transcript::BasicTranscript as Transcript; use witness::next_pow2_instance_padding; -#[cfg(feature = "gpu")] -fn log_gpu_mem_pool_after_shard(label: &str, shard_id: usize) { - use gkr_iop::gpu::gpu_prover::*; - - info_span!("[ceno] log_gpu_mem_pool_after_shard").in_scope(|| { - let cuda_hal = get_cuda_hal().unwrap(); - let mem_pool = cuda_hal.inner().mem_pool(); - let used_bytes = mem_pool.get_used_size().unwrap_or(0); - let reserved_bytes = mem_pool.get_reserved_size().unwrap_or(0); - tracing::info!( - "[gpu shard end][{label}] shard_id={} used={:.2}MB reserved={:.2}MB", - shard_id, - used_bytes as f64 / (1024.0 * 1024.0), - reserved_bytes as f64 / (1024.0 * 1024.0), - ); - }); -} - // default value: 16GB VRAM, each cell 4 byte, log explosion 2 pub const DEFAULT_MAX_CELLS_PER_SHARDS: u64 = (1 << 30) * 16 / 4 / 2; pub const DEFAULT_MAX_CYCLE_PER_SHARDS: Cycle = 1 << 29; @@ -2193,9 +2176,18 @@ fn create_proofs_streaming< let transcript = Transcript::new(b"riscv"); let start = std::time::Instant::now(); - let zkvm_proof = prover - .create_proof(&shard_ctx, zkvm_witness, pi, transcript) - .expect("create_proof failed"); + let zkvm_proof = + match prover.create_proof(&shard_ctx, zkvm_witness, pi, transcript) { + Ok(proof) => proof, + Err(err) => { + eprintln!( + "create_proof failed for shard {}: {err:?}", + shard_ctx.shard_id + ); + let _ = std::io::stderr().flush(); + std::process::exit(1); + } + }; tracing::debug!( "{}th shard proof created in {:?}", shard_ctx.shard_id, @@ -2203,9 +2195,7 @@ fn create_proofs_streaming< ); #[cfg(feature = "gpu")] if crate::instructions::gpu::config::is_gpu_witgen_enabled() { - log_gpu_mem_pool_after_shard("before_release", shard_ctx.shard_id); crate::instructions::gpu::cache::release_all_shard_gpu_caches(); - log_gpu_mem_pool_after_shard("after_release", shard_ctx.shard_id); } #[cfg(feature = "gpu")] if let Some(baseline) = _witgen_mem_baseline { @@ -2254,9 +2244,18 @@ fn create_proofs_streaming< let transcript = Transcript::new(b"riscv"); let start = std::time::Instant::now(); - let zkvm_proof = prover - .create_proof(&shard_ctx, zkvm_witness, pi, transcript) - .expect("create_proof failed"); + let zkvm_proof = + match prover.create_proof(&shard_ctx, zkvm_witness, pi, transcript) { + Ok(proof) => proof, + Err(err) => { + eprintln!( + "create_proof failed for shard {}: {err:?}", + shard_ctx.shard_id + ); + let _ = std::io::stderr().flush(); + std::process::exit(1); + } + }; tracing::debug!( "{}th shard proof created in {:?}", shard_ctx.shard_id, @@ -2264,9 +2263,7 @@ fn create_proofs_streaming< ); #[cfg(feature = "gpu")] if crate::instructions::gpu::config::is_gpu_witgen_enabled() { - log_gpu_mem_pool_after_shard("before_release", shard_ctx.shard_id); crate::instructions::gpu::cache::release_all_shard_gpu_caches(); - log_gpu_mem_pool_after_shard("after_release", shard_ctx.shard_id); } #[cfg(feature = "gpu")] if let Some(baseline) = _witgen_mem_baseline { diff --git a/ceno_zkvm/src/instructions/gpu/dispatch.rs b/ceno_zkvm/src/instructions/gpu/dispatch.rs index f7cf58969..0d3e440cb 100644 --- a/ceno_zkvm/src/instructions/gpu/dispatch.rs +++ b/ceno_zkvm/src/instructions/gpu/dispatch.rs @@ -135,6 +135,7 @@ pub(crate) fn try_gpu_assign_instances>( let total_instances = step_indices.len(); if total_instances == 0 { // Empty: just return empty matrices + let num_witin = num_witin.max(1); let num_structural_witin = num_structural_witin.max(1); let raw_witin = RowMajorMatrix::::new(0, num_witin, I::padding_strategy()); let raw_structural = diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index 0d6a32680..0680141fd 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -75,6 +75,15 @@ pub struct ZKVMChipProof { pub num_instances: [usize; 2], } +#[derive(Clone, Serialize, Deserialize)] +#[serde(bound( + serialize = "E::BaseField: Serialize", + deserialize = "E::BaseField: DeserializeOwned" +))] +pub struct MainConstraintProof { + pub proof: SumcheckLayerProof, +} + /// each field will be interpret to (constant) polynomial #[derive(Default, Clone, Debug, Serialize, Deserialize)] pub struct PublicValues { @@ -202,6 +211,7 @@ pub struct ZKVMProof> { pub public_values: PublicValues, // each circuit may have multiple proof instances pub chip_proofs: BTreeMap>>, + pub main_constraint_proof: MainConstraintProof, pub witin_commit: >::Commitment, pub opening_proof: PCS::Proof, } @@ -210,12 +220,14 @@ impl> ZKVMProof { pub fn new( public_values: PublicValues, chip_proofs: BTreeMap>>, + main_constraint_proof: MainConstraintProof, witin_commit: >::Commitment, opening_proof: PCS::Proof, ) -> Self { Self { public_values, chip_proofs, + main_constraint_proof, witin_commit, opening_proof, } diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index b1640f7b4..a1c38b56f 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -1,10 +1,12 @@ use super::hal::{ - DeviceTransporter, MainSumcheckEvals, MainSumcheckProver, OpeningProver, ProverDevice, - RotationProver, RotationProverOutput, TowerProver, TraceCommitter, + BatchedMainConstraintProver, DeviceTransporter, MainConstraintJob, MainConstraintResult, + MainSumcheckEvals, MainSumcheckProver, OpeningProver, ProverDevice, RotationProver, + RotationProverOutput, TowerProver, TraceCommitter, }; use crate::{ error::ZKVMError, scheme::{ + MainConstraintProof, constants::{NUM_FANIN, SEPTIC_EXTENSION_DEGREE}, hal::{DeviceProvingKey, EccQuarkProver, ProofInput, TowerProverSpec}, septic_curve::{SepticExtension, SepticPoint, SymbolicSepticExtension}, @@ -29,7 +31,9 @@ use mpcs::{Point, PolynomialCommitmentScheme}; use multilinear_extensions::{ Expression, ToExpr, mle::{ArcMultilinearExtension, FieldType, IntoMLE, MultilinearExtension}, + monomial::Term, util::ceil_log2, + utils::eval_by_expr_with_instance, virtual_poly::{build_eq_x_r_vec, eq_eval}, virtual_polys::VirtualPolynomialsBuilder, }; @@ -44,7 +48,7 @@ use std::{ }; use sumcheck::{ macros::{entered_span, exit_span}, - structs::{IOPProverMessage, IOPProverState}, + structs::{IOPProof, IOPProverMessage, IOPProverState}, util::{get_challenge_pows, optimal_sumcheck_threads}, }; use transcript::Transcript; @@ -1098,6 +1102,328 @@ impl> MainSumcheckProver> + BatchedMainConstraintProver> for CpuProver> +{ + fn prove_batched_main_constraints<'a>( + &self, + jobs: Vec>>, + _pcs_data: & as ProverBackend>::PcsData, + transcript: &mut impl Transcript, + ) -> Result<(MainConstraintProof, Vec>), ZKVMError> { + struct ChipMainData<'a, E: ExtensionField> { + circuit_idx: usize, + layer: &'a gkr_iop::gkr::layer::Layer, + mle_start: usize, + num_mles: usize, + num_var_with_rotation: usize, + pi: Vec, + alpha_start: usize, + } + + if jobs.is_empty() { + return Ok(( + MainConstraintProof { + proof: gkr_iop::gkr::layer::sumcheck_layer::SumcheckLayerProof { + proof: IOPProof { proofs: vec![] }, + evals: vec![], + }, + }, + vec![], + )); + } + + let mut owned_mles = Vec::>::new(); + let mut chip_data = Vec::with_capacity(jobs.len()); + let mut total_exprs = 0usize; + let mut max_num_variables = 0usize; + let mut max_degree = 0usize; + + for job in &jobs { + let ComposedConstrainSystem { + zkvm_v1_css: cs, + gkr_circuit, + } = job.cs; + + let num_instances = job.input.num_instances(); + let log2_num_instances = job.input.log2_num_instances(); + let num_var_with_rotation = log2_num_instances + job.cs.rotation_vars().unwrap_or(0); + max_num_variables = max_num_variables.max(num_var_with_rotation); + + let Some(gkr_circuit) = gkr_circuit else { + panic!("empty gkr circuit") + }; + let first_layer = gkr_circuit.layers.first().expect("empty gkr circuit layer"); + max_degree = max_degree.max(first_layer.max_expr_degree + 1); + let group_stage_masks = first_layer_output_group_stage_masks(job.cs, gkr_circuit); + let selector_ctxs = first_layer + .out_sel_and_eval_exprs + .iter() + .zip_eq(group_stage_masks.iter()) + .map(|((selector, _), stage_mask)| { + if !stage_mask.contains(GkrOutputStageMask::TOWER) || cs.ec_final_sum.is_empty() + { + SelectorContext { + offset: 0, + num_instances, + num_vars: num_var_with_rotation, + } + } else if cs.r_selector.as_ref() == Some(selector) { + SelectorContext { + offset: 0, + num_instances: job.input.num_instances[0], + num_vars: num_var_with_rotation, + } + } else if cs.w_selector.as_ref() == Some(selector) { + SelectorContext { + offset: job.input.num_instances[0], + num_instances: job.input.num_instances[1], + num_vars: num_var_with_rotation, + } + } else { + SelectorContext { + offset: 0, + num_instances, + num_vars: num_var_with_rotation, + } + } + }) + .collect_vec(); + + let mut out_evals = + vec![PointAndEval::new(job.rt_tower.clone(), E::ZERO); gkr_circuit.n_evaluations]; + + if let Some(rotation) = job.rotation.as_ref() { + let Some([left_group_idx, right_group_idx, point_group_idx]) = + first_layer.rotation_selector_group_indices() + else { + panic!("rotation proof provided for non-rotation layer") + }; + let (left_evals, right_evals, point_evals) = + split_rotation_evals(&rotation.proof.evals); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[left_group_idx].1, + &left_evals, + &rotation.left_point, + ); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[right_group_idx].1, + &right_evals, + &rotation.right_point, + ); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[point_group_idx].1, + &point_evals, + &rotation.point, + ); + } + + if let Some(ecc_proof) = job.ecc_proof.as_ref() { + let Some( + [ + x_group_idx, + y_group_idx, + slope_group_idx, + x3_group_idx, + y3_group_idx, + ], + ) = first_layer.ecc_bridge_group_indices() + else { + panic!("ecc proof provided for non-ecc layer") + }; + + let sample_r = transcript.sample_and_append_vec(b"ecc_gkr_bridge_r", 1)[0]; + let claims = derive_ecc_bridge_claims(ecc_proof, sample_r, num_var_with_rotation) + .expect("invalid internal ecc bridge claims"); + + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[x_group_idx].1, + &claims.x_evals, + &claims.xy_point, + ); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[y_group_idx].1, + &claims.y_evals, + &claims.xy_point, + ); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[slope_group_idx].1, + &claims.s_evals, + &claims.s_point, + ); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[x3_group_idx].1, + &claims.x3_evals, + &claims.x3y3_point, + ); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[y3_group_idx].1, + &claims.y3_evals, + &claims.x3y3_point, + ); + } + + let eval_and_dedup_points = first_layer + .out_sel_and_eval_exprs + .iter() + .map(|(_, out_eval_exprs)| { + out_eval_exprs + .first() + .map(|out_eval| out_eval.evaluate(&out_evals, &job.challenges).point) + }) + .collect_vec(); + let selector_eq_pairs = first_layer + .out_sel_and_eval_exprs + .iter() + .zip(eval_and_dedup_points.iter()) + .zip(selector_ctxs.iter()) + .filter_map(|(((sel_type, _), point), selector_ctx)| { + let eq = sel_type.compute(point.as_ref()?, selector_ctx)?; + let selector_expr = match sel_type { + SelectorType::Whole(expr) + | SelectorType::Prefix(expr) + | SelectorType::OrderedSparse { + expression: expr, .. + } + | SelectorType::QuarkBinaryTreeLessThan(expr) => expr, + SelectorType::None => return None, + }; + let Expression::StructuralWitIn(wit_id, _) = selector_expr else { + panic!("selector expression must be StructuralWitIn"); + }; + let wit_id = *wit_id as usize; + assert!(wit_id < first_layer.n_structural_witin); + Some((wit_id, eq)) + }) + .collect_vec(); + let mut selector_eq_by_wit_id = vec![None; first_layer.n_structural_witin]; + for (wit_id, eq) in selector_eq_pairs { + if selector_eq_by_wit_id[wit_id].is_none() { + selector_eq_by_wit_id[wit_id] = Some(eq); + } + } + + let mle_start = owned_mles.len(); + owned_mles.extend(job.input.witness.iter().map(|mle| mle.as_ref().clone())); + owned_mles.extend(job.input.fixed.iter().map(|mle| mle.as_ref().clone())); + for (selector_eq, mle) in selector_eq_by_wit_id + .into_iter() + .zip(job.input.structural_witness.iter()) + { + owned_mles.push(selector_eq.unwrap_or_else(|| mle.as_ref().clone())); + } + let num_mles = + first_layer.n_witin + first_layer.n_fixed + first_layer.n_structural_witin; + assert_eq!(owned_mles.len() - mle_start, num_mles); + + chip_data.push(ChipMainData { + circuit_idx: job.circuit_idx, + layer: first_layer, + mle_start, + num_mles, + num_var_with_rotation, + pi: job + .input + .pi + .iter() + .map(|v| v.map_either(E::from, |v| v).into_inner()) + .collect_vec(), + alpha_start: total_exprs, + }); + total_exprs += first_layer.exprs.len(); + } + + let num_threads = optimal_sumcheck_threads(max_num_variables); + let alpha_pows = get_challenge_pows(total_exprs, transcript); + let mut builder = VirtualPolynomialsBuilder::new(num_threads, max_num_variables); + let global_mle_exprs = owned_mles + .iter() + .map(|mle| builder.lift(Either::Left(mle))) + .collect_vec(); + let mut global_expr = Expression::ZERO; + + for chip in &chip_data { + let main_sumcheck_challenges = chain!( + jobs[0].challenges.iter().copied(), + alpha_pows[chip.alpha_start..chip.alpha_start + chip.layer.exprs.len()] + .iter() + .copied() + ) + .collect_vec(); + for Term { + scalar: scalar_expr, + product, + } in chip + .layer + .main_sumcheck_expression_monomial_terms + .as_ref() + .unwrap() + { + let scalar = eval_by_expr_with_instance( + &[], + &[], + &[], + &chip.pi, + &main_sumcheck_challenges, + scalar_expr, + ); + let product_expr = product + .iter() + .map(|expr| { + let Expression::WitIn(wit_id) = expr else { + panic!("main monomial product must be converted to WitIn") + }; + global_mle_exprs[chip.mle_start + *wit_id as usize].clone() + }) + .fold(Expression::ONE, |acc, expr| acc * expr); + global_expr += Expression::Constant(scalar) * product_expr; + } + } + + let span = entered_span!("IOPProverState::prove_batched_main", profiling_4 = true); + let (proof, prover_state) = + IOPProverState::prove(builder.to_virtual_polys(&[global_expr], &[]), transcript); + let global_evals = prover_state.get_mle_flatten_final_evaluations(); + let global_rt = prover_state.collect_raw_challenges(); + transcript.append_field_element_exts(&global_evals); + exit_span!(span); + + let mut results = Vec::with_capacity(jobs.len()); + for chip in &chip_data { + let input_opening_point = global_rt[..chip.num_var_with_rotation].to_vec(); + let chip_evals = &global_evals[chip.mle_start..chip.mle_start + chip.num_mles]; + results.push(MainConstraintResult { + circuit_idx: chip.circuit_idx, + input_opening_point, + opening_evals: MainSumcheckEvals { + wits_in_evals: chip_evals[..chip.layer.n_witin].to_vec(), + fixed_in_evals: chip_evals + [chip.layer.n_witin..chip.layer.n_witin + chip.layer.n_fixed] + .to_vec(), + }, + }); + } + + Ok(( + MainConstraintProof { + proof: gkr_iop::gkr::layer::sumcheck_layer::SumcheckLayerProof { + proof, + evals: global_evals, + }, + }, + results, + )) + } +} + impl> OpeningProver> for CpuProver> { diff --git a/ceno_zkvm/src/scheme/gpu/memory.rs b/ceno_zkvm/src/scheme/gpu/memory.rs index 02617b83a..3360deff0 100644 --- a/ceno_zkvm/src/scheme/gpu/memory.rs +++ b/ceno_zkvm/src/scheme/gpu/memory.rs @@ -249,8 +249,8 @@ pub fn estimate_chip_proof_memory::new( + witness_rmm.num_instances(), + 1, + InstancePaddingStrategy::Default, + ) + } else { + witness_rmm + }; Ok(unsafe { std::mem::transmute(witness_rmm) }) }) .unwrap(); @@ -1178,6 +1201,18 @@ impl> TraceCommitter> = traces.into_values().collect(); + for (trace_idx, trace) in vec_traces.iter_mut().enumerate() { + if trace.width() == 0 { + tracing::warn!( + "[gpu] replacing zero-width witness trace at index {trace_idx} with a dummy column" + ); + *trace = witness::RowMajorMatrix::::new( + trace.num_instances(), + 1, + InstancePaddingStrategy::Default, + ); + } + } if crate::instructions::gpu::config::should_materialize_witness_on_gpu() { let span = entered_span!("[gpu] normalize_trace_backing", profiling_2 = true); @@ -1377,6 +1412,144 @@ where mles } +fn shard_ram_compact_physical_rows(col_idx: usize, num_records: usize, full_rows: usize) -> usize { + // ShardRAM witness columns are laid out as: + // 0..7 x EC coordinates + // 7..14 y EC coordinates + // 14..21 EC addition slopes + // 21..30 scalar record fields + // 30.. Poseidon2 trace columns + // + // Only the scalar record fields and Poseidon2 trace are prefix-populated + // on real record rows. EC columns also carry internal tree rows in the + // upper half, so they must keep the full logical backing. + if col_idx < 21 { full_rows } else { num_records } +} + +pub fn extract_shard_ram_witness_mles_for_trace<'a, E, PCS>( + pcs_data: & as ProverBackend>::PcsData, + trace_idx: usize, + expected_num: usize, + num_vars: usize, + num_records: usize, +) -> Vec>> +where + E: ExtensionField, + PCS: PolynomialCommitmentScheme, +{ + assert_eq!( + std::any::TypeId::of::(), + std::any::TypeId::of::(), + "GPU ShardRAM compact extraction only supports BabyBear base field", + ); + + let pcs_data_basefold: &BasefoldCommitmentWithWitnessGpu< + BB31Base, + BufferImpl, + GpuDigestLayer, + GpuMatrix<'static>, + GpuPolynomial<'static>, + > = unsafe { std::mem::transmute(pcs_data) }; + + let Some(rmms) = pcs_data_basefold.rmms.as_ref() else { + return extract_witness_mles_for_trace::( + pcs_data, + trace_idx, + expected_num, + num_vars, + ); + }; + let rmm = &rmms[trace_idx]; + assert_eq!( + rmm.width(), + expected_num, + "ShardRAM trace width mismatch: expected {}, got {}", + expected_num, + rmm.width(), + ); + + let cuda_hal = get_cuda_hal().unwrap(); + let full_rows = rmm.height(); + assert_eq!( + full_rows, + 1usize << num_vars, + "ShardRAM trace height must match logical num_vars", + ); + assert!( + num_records <= full_rows, + "ShardRAM compact rows exceed full rows: {} > {}", + num_records, + full_rows, + ); + + let mles = if rmm.device_backing_layout() == Some(DeviceMatrixLayout::ColMajor) { + let device_buffer = rmm + .device_backing_ref::>() + .unwrap_or_else(|| panic!("ShardRAM col-major device backing type mismatch")); + let elem_size = std::mem::size_of::(); + let col_stride_bytes = full_rows * elem_size; + (0..expected_num) + .map(|col_idx| { + let physical_rows = + shard_ram_compact_physical_rows(col_idx, num_records, full_rows); + let start = col_idx * col_stride_bytes; + let end = start + physical_rows * elem_size; + let view_buf = device_buffer.owned_subrange(start..end); + let view_poly = GpuPolynomial::new(view_buf, num_vars); + let poly_static: GpuPolynomial<'static> = unsafe { std::mem::transmute(view_poly) }; + let mle_static = MultilinearExtensionGpu::from_ceno_gpu_base(poly_static); + Arc::new(unsafe { + std::mem::transmute::< + MultilinearExtensionGpu<'static, E>, + MultilinearExtensionGpu<'a, E>, + >(mle_static) + }) + }) + .collect::>() + } else { + let values = rmm.values(); + (0..expected_num) + .map(|col_idx| { + let physical_rows = + shard_ram_compact_physical_rows(col_idx, num_records, full_rows); + let mut column = Vec::with_capacity(physical_rows); + column.extend((0..physical_rows).map(|row| values[row * expected_num + col_idx])); + let column_bb31: Vec = unsafe { + let mut column = std::mem::ManuallyDrop::new(column); + Vec::from_raw_parts( + column.as_mut_ptr() as *mut BB31Base, + column.len(), + column.capacity(), + ) + }; + let gpu_poly = cuda_hal + .alloc_elems_from_host(&column_bb31, None) + .map(|buffer| GpuPolynomial::new(buffer, num_vars)) + .unwrap_or_else(|err| panic!("ShardRAM compact H2D failed: {err:?}")); + let mle_static = MultilinearExtensionGpu::from_ceno_gpu_base(gpu_poly); + Arc::new(unsafe { + std::mem::transmute::< + MultilinearExtensionGpu<'static, E>, + MultilinearExtensionGpu<'a, E>, + >(mle_static) + }) + }) + .collect::>() + }; + + eprintln!( + "[ceno][shard-ram-compact-mle] trace_idx={} cols={} records={} full_rows={} compact_cols={}", + trace_idx, + expected_num, + num_records, + full_rows, + expected_num.saturating_sub(21), + ); + let _ = std::io::stderr().flush(); + + mles +} + pub fn extract_witness_mles_for_trace_rmm<'a, E>( witness_rmm: witness::RowMajorMatrix<::BaseField>, ) -> Vec>> @@ -1949,6 +2122,523 @@ impl> MainSumcheckProver> + BatchedMainConstraintProver> for GpuProver> +{ + fn prove_batched_main_constraints<'a>( + &self, + mut jobs: Vec>>, + pcs_data: & as ProverBackend>::PcsData, + transcript: &mut impl Transcript, + ) -> Result<(MainConstraintProof, Vec>), ZKVMError> { + struct ChipMainData<'a, E: ExtensionField> { + circuit_idx: usize, + layer: &'a gkr_iop::gkr::layer::Layer, + mle_start: usize, + num_mles: usize, + num_var_with_rotation: usize, + pi: Vec>, + alpha_start: usize, + } + + struct HostCommonGroup { + num_vars: usize, + term_terms: Vec, + common_mle_indices: Vec, + } + + if jobs.is_empty() { + return Ok(( + MainConstraintProof { + proof: gkr_iop::gkr::layer::sumcheck_layer::SumcheckLayerProof { + proof: IOPProof { proofs: vec![] }, + evals: vec![], + }, + }, + vec![], + )); + } + + let stream = gkr_iop::gpu::get_thread_stream(); + let cuda_hal = get_cuda_hal().map_err(hal_to_backend_error)?; + for job in jobs.iter_mut() { + let num_vars = job.input.log2_num_instances() + job.cs.rotation_vars().unwrap_or(0); + if job.input.witness.is_empty() { + if let Some(trace_idx) = job.witness_trace_idx { + job.input.witness = + info_span!("[ceno] extract_main_witness_mles").in_scope(|| { + if job.circuit_name == "ShardRamCircuit" { + extract_shard_ram_witness_mles_for_trace::( + pcs_data, + trace_idx, + job.num_witin, + num_vars, + job.input.num_instances(), + ) + } else { + extract_witness_mles_for_trace::( + pcs_data, + trace_idx, + job.num_witin, + num_vars, + ) + } + }); + } + } + if job.input.structural_witness.is_empty() { + if let Some(rmm) = job.structural_rmm.as_ref() { + let num_structural_witin = job.cs.zkvm_v1_css.num_structural_witin as usize; + job.input.structural_witness = + info_span!("[ceno] transport_main_structural_witness").in_scope(|| { + transport_structural_witness_to_gpu::( + rmm, + num_structural_witin, + num_vars, + ) + }); + } + } + } + let mut selector_eqs_by_chip = Vec::with_capacity(jobs.len()); + let mut chip_data = Vec::with_capacity(jobs.len()); + let mut total_exprs = 0usize; + let mut total_mles = 0usize; + let mut max_num_variables = 0usize; + + for job in &jobs { + let ComposedConstrainSystem { + zkvm_v1_css: cs, + gkr_circuit, + } = job.cs; + let num_instances = job.input.num_instances(); + let log2_num_instances = job.input.log2_num_instances(); + let num_var_with_rotation = log2_num_instances + job.cs.rotation_vars().unwrap_or(0); + max_num_variables = max_num_variables.max(num_var_with_rotation); + + let Some(gkr_circuit) = gkr_circuit else { + panic!("empty gkr circuit") + }; + let first_layer = gkr_circuit.layers.first().expect("empty gkr circuit layer"); + let group_stage_masks = first_layer_output_group_stage_masks(job.cs, gkr_circuit); + let selector_ctxs = first_layer + .out_sel_and_eval_exprs + .iter() + .zip_eq(group_stage_masks.iter()) + .map(|((selector, _), stage_mask)| { + if !stage_mask.contains(GkrOutputStageMask::TOWER) || cs.ec_final_sum.is_empty() + { + SelectorContext { + offset: 0, + num_instances, + num_vars: num_var_with_rotation, + } + } else if cs.r_selector.as_ref() == Some(selector) { + SelectorContext { + offset: 0, + num_instances: job.input.num_instances[0], + num_vars: num_var_with_rotation, + } + } else if cs.w_selector.as_ref() == Some(selector) { + SelectorContext { + offset: job.input.num_instances[0], + num_instances: job.input.num_instances[1], + num_vars: num_var_with_rotation, + } + } else { + SelectorContext { + offset: 0, + num_instances, + num_vars: num_var_with_rotation, + } + } + }) + .collect_vec(); + + let mut out_evals = + vec![PointAndEval::new(job.rt_tower.clone(), E::ZERO); gkr_circuit.n_evaluations]; + + if let Some(rotation) = job.rotation.as_ref() { + let Some([left_group_idx, right_group_idx, point_group_idx]) = + first_layer.rotation_selector_group_indices() + else { + panic!("rotation proof provided for non-rotation layer") + }; + let (left_evals, right_evals, point_evals) = + split_rotation_evals(&rotation.proof.evals); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[left_group_idx].1, + &left_evals, + &rotation.left_point, + ); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[right_group_idx].1, + &right_evals, + &rotation.right_point, + ); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[point_group_idx].1, + &point_evals, + &rotation.point, + ); + } + + if let Some(ecc_proof) = job.ecc_proof.as_ref() { + let Some( + [ + x_group_idx, + y_group_idx, + slope_group_idx, + x3_group_idx, + y3_group_idx, + ], + ) = first_layer.ecc_bridge_group_indices() + else { + panic!("ecc proof provided for non-ecc layer") + }; + let sample_r = transcript.sample_and_append_vec(b"ecc_gkr_bridge_r", 1)[0]; + let claims = derive_ecc_bridge_claims(ecc_proof, sample_r, num_var_with_rotation) + .expect("invalid internal ecc bridge claims"); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[x_group_idx].1, + &claims.x_evals, + &claims.xy_point, + ); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[y_group_idx].1, + &claims.y_evals, + &claims.xy_point, + ); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[slope_group_idx].1, + &claims.s_evals, + &claims.s_point, + ); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[x3_group_idx].1, + &claims.x3_evals, + &claims.x3y3_point, + ); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[y3_group_idx].1, + &claims.y3_evals, + &claims.x3y3_point, + ); + } + + let eval_and_dedup_points = first_layer + .out_sel_and_eval_exprs + .iter() + .map(|(_, out_eval_exprs)| { + out_eval_exprs + .first() + .map(|out_eval| out_eval.evaluate(&out_evals, &job.challenges).point) + }) + .collect_vec(); + let selector_eq_pairs = first_layer + .out_sel_and_eval_exprs + .iter() + .zip(eval_and_dedup_points.iter()) + .zip(selector_ctxs.iter()) + .filter_map(|(((sel_type, _), point), selector_ctx)| { + let eq = gkr_iop::gkr::layer::gpu::utils::build_eq_x_r_with_sel_gpu( + &cuda_hal, + point.as_ref()?, + selector_ctx, + sel_type, + ); + let selector_expr = match sel_type { + SelectorType::Whole(expr) + | SelectorType::Prefix(expr) + | SelectorType::OrderedSparse { + expression: expr, .. + } + | SelectorType::QuarkBinaryTreeLessThan(expr) => expr, + SelectorType::None => return None, + }; + let Expression::StructuralWitIn(wit_id, _) = selector_expr else { + panic!("selector expression must be StructuralWitIn"); + }; + Some((*wit_id as usize, eq)) + }) + .collect_vec(); + let mut selector_eq_by_wit_id = vec![None; first_layer.n_structural_witin]; + for (wit_id, eq) in selector_eq_pairs { + if selector_eq_by_wit_id[wit_id].is_none() { + selector_eq_by_wit_id[wit_id] = Some(eq); + } + } + selector_eqs_by_chip.push(selector_eq_by_wit_id); + + let num_mles = + first_layer.n_witin + first_layer.n_fixed + first_layer.n_structural_witin; + chip_data.push(ChipMainData { + circuit_idx: job.circuit_idx, + layer: first_layer, + mle_start: total_mles, + num_mles, + num_var_with_rotation, + pi: job.input.pi.clone(), + alpha_start: total_exprs, + }); + total_mles += num_mles; + total_exprs += first_layer.exprs.len(); + } + let mut all_witins_gpu = Vec::with_capacity(total_mles); + for ((job, chip), selector_eq_by_wit_id) in jobs + .iter() + .zip(chip_data.iter()) + .zip(selector_eqs_by_chip.iter()) + { + all_witins_gpu.extend(job.input.witness.iter().map(|mle| mle.as_ref())); + all_witins_gpu.extend(job.input.fixed.iter().map(|mle| mle.as_ref())); + for (selector_eq, mle) in selector_eq_by_wit_id + .iter() + .zip(job.input.structural_witness.iter()) + { + if let Some(eq) = selector_eq.as_ref() { + all_witins_gpu.push(eq); + } else { + all_witins_gpu.push(mle.as_ref()); + } + } + assert_eq!( + all_witins_gpu.len(), + chip.mle_start + chip.num_mles, + "invalid gpu main witness layout" + ); + } + let alpha_pows = get_challenge_pows(total_exprs, transcript); + let mut term_coefficients = Vec::new(); + let mut mle_indices_per_term = Vec::new(); + let mut mle_size_info = Vec::new(); + let mut common_groups = Vec::new(); + for chip in &chip_data { + let main_sumcheck_challenges = chain!( + jobs[0].challenges.iter().copied(), + alpha_pows[chip.alpha_start..chip.alpha_start + chip.layer.exprs.len()] + .iter() + .copied() + ) + .collect_vec(); + let common_plan = chip.layer.main_sumcheck_expression_common_factored.as_ref(); + let monomial_terms = match ( + common_plan, + chip.layer + .main_sumcheck_expression_monomial_terms_excluded_shared + .as_ref(), + ) { + (Some(_), Some(residual_terms)) => residual_terms, + (Some(_), None) => { + panic!("common factoring plan present without residual monomials") + } + (None, Some(terms)) => terms, + (None, None) => chip + .layer + .main_sumcheck_expression_monomial_terms + .as_ref() + .unwrap(), + }; + let term_start = term_coefficients.len(); + for term in monomial_terms { + let scalar = + eval_by_expr_constant(&chip.pi, &main_sumcheck_challenges, &term.scalar) + .map_either(E::from, |v| v) + .into_inner(); + term_coefficients.push(scalar); + let indices = term + .product + .iter() + .map(|expr| { + let Expression::WitIn(wit_id) = expr else { + panic!("main monomial product must be converted to WitIn") + }; + chip.mle_start + *wit_id as usize + }) + .collect_vec(); + let first_idx = indices.first().copied(); + mle_indices_per_term.push(indices); + if let Some(first_idx) = first_idx { + let num_vars = all_witins_gpu[first_idx].mle.num_vars(); + mle_size_info.push((num_vars, num_vars)); + } else { + mle_size_info.push((0, 0)); + } + } + let mut covered_terms = vec![false; monomial_terms.len()]; + if let Some(common_plan) = common_plan { + for group in &common_plan.groups { + assert!( + !group.term_indices.is_empty(), + "common term group must include at least one term" + ); + let mut group_term_terms = Vec::with_capacity(group.term_indices.len()); + for &term_idx in &group.term_indices { + assert!( + term_idx < monomial_terms.len(), + "common term index {} out of range (terms={})", + term_idx, + monomial_terms.len() + ); + covered_terms[term_idx] = true; + group_term_terms.push( + u32::try_from(term_start + term_idx) + .expect("term index exceeds supported range for GPU plan"), + ); + } + + let mut group_mle_indices = Vec::with_capacity(group.witness_indices.len()); + for &wit_idx in &group.witness_indices { + assert!( + wit_idx < chip.num_mles, + "common witness index {} out of range (mles={})", + wit_idx, + chip.num_mles + ); + group_mle_indices.push( + u32::try_from(chip.mle_start + wit_idx) + .expect("witness index exceeds supported range for GPU plan"), + ); + } + common_groups.push(HostCommonGroup { + num_vars: chip.num_var_with_rotation, + term_terms: group_term_terms, + common_mle_indices: group_mle_indices, + }); + } + } + let mut uncovered_terms = Vec::new(); + for (term_idx, covered) in covered_terms.iter().copied().enumerate() { + if !covered { + uncovered_terms.push( + u32::try_from(term_start + term_idx) + .expect("term index exceeds supported range for GPU plan"), + ); + } + } + if !uncovered_terms.is_empty() { + common_groups.push(HostCommonGroup { + num_vars: chip.num_var_with_rotation, + term_terms: uncovered_terms, + common_mle_indices: Vec::new(), + }); + } + } + + common_groups.sort_by(|lhs, rhs| rhs.num_vars.cmp(&lhs.num_vars)); + + let mut common_term_offsets = Vec::with_capacity(common_groups.len() + 1); + let mut common_term_terms = Vec::new(); + let mut common_mle_offsets = Vec::with_capacity(common_groups.len() + 1); + let mut common_mle_indices = Vec::new(); + common_term_offsets.push(0); + common_mle_offsets.push(0); + for group in &common_groups { + common_term_terms.extend(group.term_terms.iter().copied()); + common_term_offsets.push(common_term_terms.len() as u32); + common_mle_indices.extend(group.common_mle_indices.iter().copied()); + common_mle_offsets.push(common_mle_indices.len() as u32); + } + + let max_degree = common_groups + .iter() + .map(|group| { + let common_len = group.common_mle_indices.len(); + let max_residual_len = group + .term_terms + .iter() + .map(|&term_idx| mle_indices_per_term[term_idx as usize].len()) + .max() + .unwrap_or(0); + common_len + max_residual_len + }) + .max() + .unwrap_or(0); + let basic_transcript = expect_basic_transcript(transcript); + let common_scalar_offsets = vec![0u32; common_mle_offsets.len()]; + let common_term_plan = CommonTermPlan { + term_offsets: common_term_offsets, + term_terms: common_term_terms, + common_mle_offsets, + common_mle_indices, + common_scalar_offsets, + common_scalar_indices: vec![], + }; + let term_coefficients_gl64: Vec = + unsafe { std::mem::transmute(term_coefficients) }; + let all_witins_gpu_gl64: Vec<&MultilinearExtensionGpu> = + unsafe { std::mem::transmute(all_witins_gpu) }; + let all_witins_gpu_type_gl64 = all_witins_gpu_gl64.iter().map(|mle| &mle.mle).collect_vec(); + let (proof_gpu, evals_gpu, challenges_gpu) = cuda_hal + .sumcheck + .prove_generic_sumcheck_gpu_v2( + cuda_hal.as_ref(), + all_witins_gpu_type_gl64, + &mle_size_info, + &term_coefficients_gl64, + &mle_indices_per_term, + max_num_variables, + max_degree, + Some(&common_term_plan), + basic_transcript, + stream.as_ref(), + ) + .map_err(|e| hal_to_backend_error(format!("GPU main sumcheck failed: {e:?}")))?; + let proof: IOPProof = unsafe { std::mem::transmute(proof_gpu) }; + let evals_gpu_e: Vec> = unsafe { std::mem::transmute(evals_gpu) }; + let global_evals = evals_gpu_e.into_iter().flatten().collect_vec(); + let global_rt: Point = unsafe { + std::mem::transmute::, Vec>( + challenges_gpu.iter().map(|c| c.elements).collect(), + ) + }; + + transcript.append_field_element_exts(&global_evals); + + let mut results = Vec::with_capacity(chip_data.len()); + for chip in &chip_data { + let input_opening_point = + frontload_input_opening_point(&global_rt, chip.num_var_with_rotation); + let chip_evals = &global_evals[chip.mle_start..chip.mle_start + chip.num_mles]; + results.push(MainConstraintResult { + circuit_idx: chip.circuit_idx, + input_opening_point, + opening_evals: MainSumcheckEvals { + wits_in_evals: chip_evals[..chip.layer.n_witin].to_vec(), + fixed_in_evals: chip_evals + [chip.layer.n_witin..chip.layer.n_witin + chip.layer.n_fixed] + .to_vec(), + }, + }); + } + + Ok(( + MainConstraintProof { + proof: gkr_iop::gkr::layer::sumcheck_layer::SumcheckLayerProof { + proof, + evals: global_evals, + }, + }, + results, + )) + } +} + +fn frontload_input_opening_point( + global_rt: &[E], + num_var_with_rotation: usize, +) -> Point { + global_rt[..num_var_with_rotation].to_vec() +} + impl> RotationProver> for GpuProver> { @@ -2090,6 +2780,143 @@ impl> OpeningProver( + prover: &GpuProver>, + witness_data: as ProverBackend>::PcsData, + fixed_data: Option as ProverBackend>::PcsData>>, + replayable_traces: &[(usize, crate::structs::GpuReplayPlan)], + points: Vec>, + mut evals: Vec>>, + transcript: &mut (impl Transcript + 'static), +) -> PCS::Proof +where + E: ExtensionField, + PCS: PolynomialCommitmentScheme, +{ + if std::any::TypeId::of::() != std::any::TypeId::of::() { + panic!("GPU backend only supports BabyBear base field"); + } + + let mut rounds = vec![]; + rounds.push((&witness_data, { + evals + .iter_mut() + .zip(&points) + .filter_map(|(evals, point)| { + let witin_evals = evals.remove(0); + if !witin_evals.is_empty() { + Some((point.clone(), witin_evals)) + } else { + None + } + }) + .collect_vec() + })); + if let Some(fixed_data) = fixed_data.as_ref().map(|f| f.as_ref()) { + rounds.push((fixed_data, { + evals + .iter_mut() + .zip(points.iter().cloned()) + .filter_map(|(evals, point)| { + if !evals.is_empty() && !evals[0].is_empty() { + Some((point.clone(), evals.remove(0))) + } else { + None + } + }) + .collect_vec() + })); + } + + let prover_param = &prover.backend.pp; + let pp_gl64: &mpcs::basefold::structure::BasefoldProverParams = + unsafe { std::mem::transmute(prover_param) }; + let rounds_gl64: Vec<_> = rounds + .iter() + .map(|(commitment, point_eval_pairs)| { + let commitment_gl64: &BasefoldCommitmentWithWitnessGpu< + BB31Base, + BufferImpl, + GpuDigestLayer, + GpuMatrix<'static>, + GpuPolynomial<'static>, + > = unsafe { std::mem::transmute(*commitment) }; + let point_eval_pairs_gl64: Vec<_> = point_eval_pairs + .iter() + .map(|(point, evals)| { + let point_gl64: &Vec = unsafe { std::mem::transmute(point) }; + let evals_gl64: &Vec = unsafe { std::mem::transmute(evals) }; + (point_gl64.clone(), evals_gl64.clone()) + }) + .collect(); + (commitment_gl64, point_eval_pairs_gl64) + }) + .collect(); + + if std::any::TypeId::of::() != std::any::TypeId::of::() { + panic!("GPU backend only supports BabyBear field extension"); + } + + let transcript_any = transcript as &mut dyn std::any::Any; + let basic_transcript = transcript_any + .downcast_mut::>() + .expect("Type should match"); + + let cuda_hal = get_cuda_hal().unwrap(); + let gpu_proof_basefold = cuda_hal + .basefold + .batch_open_with_trace_materializer( + &cuda_hal, + pp_gl64, + rounds_gl64, + basic_transcript, + |round_idx, trace_idx| { + if round_idx != 0 { + return Ok(None); + } + let Some((_, replay_plan)) = replayable_traces + .iter() + .find(|(replay_trace_idx, _)| *replay_trace_idx == trace_idx) + else { + return Ok(None); + }; + let witness_rmm = info_span!( + "[ceno] replay_witness_materialize", + phase = "pcs_opening", + round_idx, + trace_idx, + kind = ?replay_plan.kind, + rows = replay_plan.trace_height, + num_witin = replay_plan.num_witin, + steps = replay_plan.step_indices.len(), + ) + .in_scope(|| replay_plan.replay_witness()) + .map_err(|err| { + ceno_gpu::HalError::InvalidInput(format!( + "failed to replay trace {trace_idx} for PCS opening: {err:?}" + )) + })?; + if witness_rmm.height() != replay_plan.trace_height { + return Err(ceno_gpu::HalError::InvalidInput(format!( + "replayed trace {trace_idx} height changed before PCS opening: expected {}, got {}", + replay_plan.trace_height, + witness_rmm.height(), + ))); + } + let witness_rmm_bb31: witness::RowMajorMatrix = + unsafe { std::mem::transmute(witness_rmm) }; + Ok(Some(witness_rmm_bb31)) + }, + ) + .unwrap(); + + let gpu_proof: PCS::Proof = unsafe { std::mem::transmute_copy(&gpu_proof_basefold) }; + std::mem::forget(gpu_proof_basefold); + drop(rounds); + drop(witness_data); + gpu_proof +} + impl> DeviceTransporter> for GpuProver> { @@ -2182,41 +3009,44 @@ impl> pcs_data: & as gkr_iop::hal::ProverBackend>::PcsData, ) { if let Some(replay_plan) = task.gpu_replay_plan.as_ref() { - let cuda_hal = get_cuda_hal().unwrap(); - let gpu_mem_tracker = init_gpu_mem_tracker(&cuda_hal, "replay_gpu_witness_from_raw"); let num_vars = task.input.log2_num_instances() + task.pk.get_cs().rotation_vars().unwrap_or(0); - let estimated_replay_bytes = - estimate_replay_materialization_bytes_for_plan(replay_plan, num_vars); - tracing::info!( - "[gpu] replaying witness from raw: circuit={}, estimated={:.2}MB", - task.circuit_name, - estimated_replay_bytes as f64 / (1024.0 * 1024.0), - ); - task.input.witness = if let Some(trace_idx) = task.witness_trace_idx { - check_gpu_mem_estimation_with_context( - gpu_mem_tracker, - 0, - Some(task.circuit_name.as_str()), - ); - info_span!("[ceno] extract_witness_mles").in_scope(|| { - extract_witness_mles_for_trace::( - pcs_data, - trace_idx, - task.num_witin, - num_vars, - ) - }) - } else { - let witness_rmm = replay_plan.replay_witness().expect("GPU raw replay failed"); - check_gpu_mem_estimation_with_context( - gpu_mem_tracker, - estimated_replay_bytes, - Some(task.circuit_name.as_str()), + if task.num_witin > 0 { + let cuda_hal = get_cuda_hal().unwrap(); + let gpu_mem_tracker = + init_gpu_mem_tracker(&cuda_hal, "replay_gpu_witness_from_raw"); + let estimated_replay_bytes = + estimate_replay_materialization_bytes_for_plan(replay_plan, num_vars); + tracing::info!( + "[gpu] replaying witness from raw: circuit={}, estimated={:.2}MB", + task.circuit_name, + estimated_replay_bytes as f64 / (1024.0 * 1024.0), ); - info_span!("[ceno] replay_gpu_witness_from_raw") - .in_scope(|| extract_witness_mles_for_trace_rmm::(witness_rmm)) - }; + task.input.witness = if let Some(trace_idx) = task.witness_trace_idx { + check_gpu_mem_estimation_with_context( + gpu_mem_tracker, + 0, + Some(task.circuit_name.as_str()), + ); + info_span!("[ceno] extract_witness_mles").in_scope(|| { + extract_witness_mles_for_trace::( + pcs_data, + trace_idx, + task.num_witin, + num_vars, + ) + }) + } else { + let witness_rmm = replay_plan.replay_witness().expect("GPU raw replay failed"); + check_gpu_mem_estimation_with_context( + gpu_mem_tracker, + estimated_replay_bytes, + Some(task.circuit_name.as_str()), + ); + info_span!("[ceno] replay_gpu_witness_from_raw") + .in_scope(|| extract_witness_mles_for_trace_rmm::(witness_rmm)) + }; + } if let Some(rmm) = task.structural_rmm.as_ref() { task.input.structural_witness = info_span!("[ceno] transport_structural_witness") .in_scope(|| { diff --git a/ceno_zkvm/src/scheme/hal.rs b/ceno_zkvm/src/scheme/hal.rs index 65fe06f2d..9d8acb946 100644 --- a/ceno_zkvm/src/scheme/hal.rs +++ b/ceno_zkvm/src/scheme/hal.rs @@ -20,6 +20,7 @@ pub trait ProverDevice: TraceCommitter + TowerProver + MainSumcheckProver + + BatchedMainConstraintProver + OpeningProver + DeviceTransporter + ProtocolWitnessGeneratorProver @@ -33,6 +34,23 @@ where fn get_pb(&self) -> &PB; } +pub trait BatchedMainConstraintProver { + fn prove_batched_main_constraints<'a>( + &self, + jobs: Vec>, + pcs_data: &PB::PcsData, + transcript: &mut impl Transcript, + ) -> BatchedMainConstraintResult; +} + +pub type BatchedMainConstraintResult = Result< + ( + crate::scheme::MainConstraintProof, + Vec>, + ), + ZKVMError, +>; + /// Prepare a chip task's input for proving. /// CPU: no-op (input already fully populated during task building). /// GPU: deferred witness extraction + structural witness transport. @@ -54,6 +72,19 @@ pub struct ProofInput<'a, PB: ProverBackend> { pub has_ecc_ops: bool, } +impl<'a, PB: ProverBackend> Clone for ProofInput<'a, PB> { + fn clone(&self) -> Self { + Self { + witness: self.witness.clone(), + structural_witness: self.structural_witness.clone(), + fixed: self.fixed.clone(), + pi: self.pi.clone(), + num_instances: self.num_instances, + has_ecc_ops: self.has_ecc_ops, + } + } +} + impl<'a, PB: ProverBackend> ProofInput<'a, PB> { pub fn num_instances(&self) -> usize { self.num_instances.iter().sum() @@ -154,6 +185,26 @@ pub struct MainSumcheckEvals { pub fixed_in_evals: Vec, } +pub struct MainConstraintJob<'a, PB: ProverBackend> { + pub circuit_name: String, + pub circuit_idx: usize, + pub input: ProofInput<'static, PB>, + pub witness_trace_idx: Option, + pub num_witin: usize, + pub structural_rmm: Option::BaseField>>, + pub rt_tower: Point, + pub rotation: Option>, + pub ecc_proof: Option>, + pub challenges: [PB::E; 2], + pub cs: &'a ComposedConstrainSystem, +} + +pub struct MainConstraintResult { + pub circuit_idx: usize, + pub input_opening_point: Point, + pub opening_evals: MainSumcheckEvals, +} + #[derive(Clone)] pub struct RotationProverOutput { pub proof: SumcheckLayerProof, diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index c0e4e42e6..0cbacff97 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -14,7 +14,7 @@ use crate::scheme::gpu::estimate_chip_proof_memory; #[cfg(feature = "gpu")] use crate::scheme::scheduler::get_chip_proving_mode; use crate::scheme::{ - hal::MainSumcheckEvals, + hal::{MainConstraintJob, MainConstraintResult, MainSumcheckEvals}, scheduler::{ChipScheduler, ChipTask, ChipTaskResult}, }; #[cfg(feature = "gpu")] @@ -47,7 +47,10 @@ use crate::{ structs::{TowerProofs, ZKVMProvingKey, ZKVMWitnesses}, }; -type CreateTableProof = (ZKVMChipProof, MainSumcheckEvals, Point); +type CreateTableProof<'a, PB> = ( + ZKVMChipProof<::E>, + MainConstraintJob<'a, PB>, +); pub type ZkVMCpuProver = ZKVMProver, CpuProver>>; @@ -247,23 +250,23 @@ impl< None }; - #[cfg(feature = "gpu")] - if use_deferred_gpu_commit { - if let Some(plan) = gpu_replay_plan.clone() { - deferred_gpu_traces - .insert(i, crate::scheme::gpu::DeferredGpuTrace::Replay(plan)); - } else if witness_rmm.num_instances() > 0 { - deferred_gpu_traces - .insert(i, crate::scheme::gpu::DeferredGpuTrace::Eager(witness_rmm)); - } - } else if witness_rmm.num_instances() > 0 { - wits_rmms.insert(i, witness_rmm); - } - - #[cfg(not(feature = "gpu"))] - if witness_rmm.num_instances() > 0 { - wits_rmms.insert(i, witness_rmm); - } + #[cfg(feature = "gpu")] + if use_deferred_gpu_commit { + if let Some(plan) = gpu_replay_plan.clone().filter(|plan| plan.num_witin > 0) { + deferred_gpu_traces + .insert(i, crate::scheme::gpu::DeferredGpuTrace::Replay(plan)); + } else if witness_rmm.num_instances() > 0 && witness_rmm.width > 0 { + deferred_gpu_traces + .insert(i, crate::scheme::gpu::DeferredGpuTrace::Eager(witness_rmm)); + } + } else if witness_rmm.num_instances() > 0 && witness_rmm.width > 0 { + wits_rmms.insert(i, witness_rmm); + } + + #[cfg(not(feature = "gpu"))] + if witness_rmm.num_instances() > 0 && witness_rmm.width > 0 { + wits_rmms.insert(i, witness_rmm); + } structural_rmms.push(structural_witness_rmm); #[cfg(feature = "gpu")] witness_trace_rows.push(trace_rows_for_estimate); @@ -541,7 +544,7 @@ impl< // Phase 3: Collect results let collect_results_span = entered_span!("collect_chip_results", profiling_1 = true); - let (chip_proofs, points, evaluations) = Self::collect_chip_results(results); + let (chip_proofs, main_constraint_jobs) = Self::collect_chip_results(results); exit_span!(collect_results_span); exit_span!(main_proofs_span); @@ -550,31 +553,95 @@ impl< transcript.append_field_element_ext(&sample); } - // batch opening pcs - // generate static info from prover key for expected num variable - let pcs_opening = entered_span!("pcs_opening", profiling_1 = true); #[cfg(feature = "gpu")] if needs_replay_restore { let replay_cache = current_replay_cache_stats(); tracing::info!( - "[gpu replay cache][before_restore_pcs] shard_steps={:.2}MB shard_meta={:.2}MB shared_side_effect={:.2}MB total={:.2}MB", + "[gpu replay cache][before_restore_main] shard_steps={:.2}MB shard_meta={:.2}MB shared_side_effect={:.2}MB total={:.2}MB", replay_cache.shard_steps_bytes as f64 / (1024.0 * 1024.0), replay_cache.shard_meta_bytes as f64 / (1024.0 * 1024.0), replay_cache.shared_side_effect_bytes as f64 / (1024.0 * 1024.0), replay_cache.total_bytes() as f64 / (1024.0 * 1024.0), ); - crate::scheme::gpu::log_gpu_device_state("before_restore_pcs"); + crate::scheme::gpu::log_gpu_device_state("before_restore_main"); let gpu_witness_data: &mut as ProverBackend>::PcsData = unsafe { std::mem::transmute(&mut witness_data) }; crate::scheme::gpu::restore_replayable_trace_device_backing::( gpu_witness_data, &replayable_traces, )?; - crate::scheme::gpu::log_gpu_device_state("after_restore_pcs"); + crate::scheme::gpu::log_gpu_device_state("after_restore_main"); + } + + let main_constraints_span = + entered_span!("prove_batched_main_constraints", profiling_1 = true); + let (main_constraint_proof, main_constraint_results) = + info_span!("[ceno] prove_batched_main_constraints").in_scope(|| { + self.device.prove_batched_main_constraints( + main_constraint_jobs, + &witness_data, + &mut transcript, + ) + })?; + let (points, evaluations) = + Self::collect_main_constraint_results(main_constraint_results); + exit_span!(main_constraints_span); + + #[cfg(feature = "gpu")] + if needs_replay_restore { + crate::scheme::gpu::log_gpu_device_state("before_clear_main_backing"); + let gpu_witness_data: &mut as ProverBackend>::PcsData = + unsafe { std::mem::transmute(&mut witness_data) }; + crate::scheme::gpu::clear_replayable_trace_device_backing::( + gpu_witness_data, + &replayable_traces, + ); + crate::scheme::gpu::log_gpu_device_state("after_clear_main_backing"); + } + + // batch opening pcs + // generate static info from prover key for expected num variable + let pcs_opening = entered_span!("pcs_opening", profiling_1 = true); + #[cfg(feature = "gpu")] + if needs_replay_restore { + let replay_cache = current_replay_cache_stats(); + tracing::info!( + "[gpu replay cache][before_pcs_opening] shard_steps={:.2}MB shard_meta={:.2}MB shared_side_effect={:.2}MB total={:.2}MB", + replay_cache.shard_steps_bytes as f64 / (1024.0 * 1024.0), + replay_cache.shard_meta_bytes as f64 / (1024.0 * 1024.0), + replay_cache.shared_side_effect_bytes as f64 / (1024.0 * 1024.0), + replay_cache.total_bytes() as f64 / (1024.0 * 1024.0), + ); + crate::scheme::gpu::log_gpu_device_state("before_pcs_opening"); } let mpcs_opening_proof = info_span!("[ceno] pcs_opening").in_scope(|| { #[cfg(feature = "gpu")] { + if needs_replay_restore { + let gpu_device: &gkr_iop::gpu::GpuProver< + gkr_iop::gpu::GpuBackend, + > = unsafe { std::mem::transmute(&self.device) }; + let gpu_witness_data: as ProverBackend>::PcsData = + unsafe { std::mem::transmute_copy(&witness_data) }; + std::mem::forget(witness_data); + let fixed_data = self + .get_device_proving_key(shard_ctx) + .map(|dpk| dpk.pcs_data.clone()); + let gpu_fixed_data: Option< + std::sync::Arc< + as ProverBackend>::PcsData, + >, + > = unsafe { std::mem::transmute(fixed_data) }; + return crate::scheme::gpu::open_with_incremental_replay::( + gpu_device, + gpu_witness_data, + gpu_fixed_data, + &replayable_traces, + points, + evaluations, + &mut transcript, + ); + } } self.device.open( witness_data, @@ -587,7 +654,13 @@ impl< }); exit_span!(pcs_opening); - let vm_proof = ZKVMProof::new(pi, chip_proofs, witin_commit, mpcs_opening_proof); + let vm_proof = ZKVMProof::new( + pi, + chip_proofs, + main_constraint_proof, + witin_commit, + mpcs_opening_proof, + ); Ok(vm_proof) }) @@ -604,13 +677,14 @@ impl< tasks: Vec>, transcript: &T, witness_data: &PB::PcsData, - ) -> Result<(Vec>, Vec), ZKVMError> { + ) -> Result<(Vec>, Vec), ZKVMError> { let scheduler = ChipScheduler::new(); #[cfg(feature = "gpu")] { - if std::any::TypeId::of::() - == std::any::TypeId::of::>() + if false + && std::any::TypeId::of::() + == std::any::TypeId::of::>() { let gpu_witness_data: & as gkr_iop::hal::ProverBackend>::PcsData = unsafe { std::mem::transmute(witness_data) }; @@ -654,11 +728,12 @@ impl< proof, opening_evals, input_opening_point, + main_constraint_job: None, has_witness_or_fixed: task.has_witness_or_fixed, }) }; - if ChipScheduler::is_concurrent_mode() { + if false && ChipScheduler::is_concurrent_mode() { // SAFETY: pcs_data is only read (via get_trace) during concurrent execution. use crate::scheme::utils::SyncRef; let gpu_wd = SyncRef(gpu_witness_data); @@ -704,15 +779,18 @@ impl< )); } - let (proof, opening_evals, input_opening_point) = - self.create_chip_proof(&task, transcript)?; + let (proof, main_constraint_job) = self.create_chip_proof(&mut task, transcript)?; Ok(ChipTaskResult { task_id: task.task_id, circuit_idx: task.circuit_idx, proof, - opening_evals, - input_opening_point, + opening_evals: MainSumcheckEvals { + wits_in_evals: vec![], + fixed_in_evals: vec![], + }, + input_opening_point: vec![], + main_constraint_job: Some(main_constraint_job), has_witness_or_fixed: task.has_witness_or_fixed, }) }) @@ -724,11 +802,11 @@ impl< /// into a single tower tree, and then feed these trees into tower prover. #[tracing::instrument(skip_all, name = "create_chip_proof", fields(table_name=%task.circuit_name, profiling_2 ), level = "trace")] - pub fn create_chip_proof( + pub fn create_chip_proof<'a>( &self, - task: &ChipTask<'_, PB>, + task: &mut ChipTask<'a, PB>, transcript: &mut impl Transcript, - ) -> Result, ZKVMError> { + ) -> Result, ZKVMError> { let circuit_pk = task.pk; let input = &task.input; let challenges = &task.challenges; @@ -814,68 +892,49 @@ impl< })?; exit_span!(span); - // 1. prove the main constraints among witness polynomials - // 2. prove the relation between last layer in the tower and read/write/logup records - let span = entered_span!("prove_main_constraints", profiling_2 = true); #[cfg(feature = "gpu")] - if task.gpu_replay_plan.as_ref().is_some_and(|plan| { - matches!( - plan.kind, - crate::instructions::gpu::dispatch::GpuWitgenKind::Keccak - ) - }) { - crate::scheme::gpu::log_gpu_pool_usage(&format!( - "{}:before_prove_main", - task.circuit_name - )); - } - let (input_opening_point, evals, main_sumcheck_proofs, gkr_iop_proof) = - info_span!("[ceno] prove_main_constraints").in_scope(|| { - self.device.prove_main_constraints( - rt_tower, - rotation.clone(), - ecc_proof.as_ref(), - input, - cs, - challenges, - transcript, - ) - })?; + let main_input = { + let mut input = input.clone(); + if std::any::TypeId::of::() + == std::any::TypeId::of::>() + { + input.witness.clear(); + input.structural_witness.clear(); + } + input + }; + #[cfg(not(feature = "gpu"))] + let main_input = input.clone(); #[cfg(feature = "gpu")] - if task.gpu_replay_plan.as_ref().is_some_and(|plan| { - matches!( - plan.kind, - crate::instructions::gpu::dispatch::GpuWitgenKind::Keccak - ) - }) { - crate::scheme::gpu::log_gpu_pool_usage(&format!( - "{}:after_prove_main", - task.circuit_name - )); - } - let MainSumcheckEvals { - wits_in_evals, - fixed_in_evals, - } = evals; - exit_span!(span); + let structural_rmm = task.structural_rmm.take(); + #[cfg(not(feature = "gpu"))] + let structural_rmm = None; Ok(( ZKVMChipProof { r_out_evals, w_out_evals, lk_out_evals, - main_sumcheck_proofs, - gkr_iop_proof, - rotation_proof: rotation.map(|r| r.proof), + main_sumcheck_proofs: None, + gkr_iop_proof: None, + rotation_proof: rotation.clone().map(|r| r.proof), tower_proof, - ecc_proof, + ecc_proof: ecc_proof.clone(), num_instances: input.num_instances, }, - MainSumcheckEvals { - wits_in_evals, - fixed_in_evals, + MainConstraintJob { + circuit_name: task.circuit_name.clone(), + circuit_idx: task.circuit_idx, + input: main_input, + witness_trace_idx: task.witness_trace_idx, + num_witin: task.num_witin, + structural_rmm, + rt_tower, + rotation, + ecc_proof, + challenges: *challenges, + cs, }, - input_opening_point, )) } @@ -1086,16 +1145,14 @@ impl< /// Phase 3: Collect chip proof results into proof components. #[allow(clippy::type_complexity)] - fn collect_chip_results( - results: Vec>, + fn collect_chip_results<'a>( + results: Vec>, ) -> ( BTreeMap>>, - Vec>, - Vec>>, + Vec>, ) { let mut chip_proofs = BTreeMap::new(); - let mut points = Vec::new(); - let mut evaluations = Vec::new(); + let mut main_constraint_jobs = Vec::new(); for result in results { tracing::trace!( @@ -1104,12 +1161,8 @@ impl< result.task_id ); - if result.has_witness_or_fixed { - points.push(result.input_opening_point); - evaluations.push(vec![ - result.opening_evals.wits_in_evals, - result.opening_evals.fixed_in_evals, - ]); + if let Some(job) = result.main_constraint_job { + main_constraint_jobs.push(job); } chip_proofs .entry(result.circuit_idx) @@ -1117,7 +1170,26 @@ impl< .push(result.proof); } - (chip_proofs, points, evaluations) + (chip_proofs, main_constraint_jobs) + } + + fn collect_main_constraint_results( + results: Vec>, + ) -> (Vec>, Vec>>) { + let mut points = Vec::new(); + let mut evaluations = Vec::new(); + for result in results { + if !result.opening_evals.wits_in_evals.is_empty() + || !result.opening_evals.fixed_in_evals.is_empty() + { + points.push(result.input_opening_point); + evaluations.push(vec![ + result.opening_evals.wits_in_evals, + result.opening_evals.fixed_in_evals, + ]); + } + } + (points, evaluations) } } @@ -1139,7 +1211,7 @@ pub fn create_chip_proof_gpu_impl<'a, E, PCS>( #[cfg(feature = "gpu")] gpu_replay_plan: Option>, num_witin: usize, structural_rmm: Option::BaseField>>, -) -> Result, ZKVMError> +) -> Result<(ZKVMChipProof, MainSumcheckEvals, Point), ZKVMError> where E: ExtensionField, PCS: PolynomialCommitmentScheme + 'static, diff --git a/ceno_zkvm/src/scheme/scheduler.rs b/ceno_zkvm/src/scheme/scheduler.rs index 060f91083..c6cd999cb 100644 --- a/ceno_zkvm/src/scheme/scheduler.rs +++ b/ceno_zkvm/src/scheme/scheduler.rs @@ -16,7 +16,7 @@ use crate::{ error::ZKVMError, scheme::{ ZKVMChipProof, - hal::{MainSumcheckEvals, ProofInput}, + hal::{MainConstraintJob, MainSumcheckEvals, ProofInput}, }, structs::ProvingKey, }; @@ -100,32 +100,34 @@ pub struct ChipTask<'a, PB: ProverBackend> { } /// Result from a completed chip proof task -pub struct ChipTaskResult { +pub struct ChipTaskResult<'a, PB: ProverBackend> { /// Task ID for ordering pub task_id: usize, /// Circuit index for proof collection pub circuit_idx: usize, /// The generated proof - pub proof: ZKVMChipProof, + pub proof: ZKVMChipProof, /// Prover-only opening evaluations split by witness/fixed/pi domains. - pub opening_evals: MainSumcheckEvals, + pub opening_evals: MainSumcheckEvals, /// Opening point for this proof - pub input_opening_point: Point, + pub input_opening_point: Point, + /// Deferred main-constraint proving job. + pub main_constraint_job: Option>, /// Whether this circuit has witness or fixed polynomials pub has_witness_or_fixed: bool, } /// Message sent from worker to scheduler on task completion #[cfg(feature = "gpu")] -struct CompletionMessage { +struct CompletionMessage<'a, PB: ProverBackend> { /// The result of the proof - result: Result, ZKVMError>, + result: Result, ZKVMError>, /// Memory that was reserved for this task (to release) memory_reserved: u64, /// Task ID for ordering task_id: usize, /// Sampled value from the forked transcript (for gather phase) - forked_sample: E, + forked_sample: PB::E, } /// Memory-aware parallel chip proof scheduler @@ -152,12 +154,12 @@ impl ChipScheduler { tasks: Vec>, transcript: &T, execute_task: F, - ) -> Result<(Vec>, Vec), ZKVMError> + ) -> Result<(Vec>, Vec), ZKVMError> where PB: ProverBackend + 'static, PB::E: Send + 'static, T: Transcript + Clone, - F: Fn(ChipTask<'a, PB>, &mut T) -> Result, ZKVMError> + Send + Sync, + F: Fn(ChipTask<'a, PB>, &mut T) -> Result, ZKVMError> + Send + Sync, { #[cfg(feature = "gpu")] { @@ -188,12 +190,12 @@ impl ChipScheduler { tasks: Vec>, parent_transcript: &T, execute_task: F, - ) -> Result<(Vec>, Vec), ZKVMError> + ) -> Result<(Vec>, Vec), ZKVMError> where PB: ProverBackend + 'static, PB::E: Send + 'static, T: Transcript + Clone, - F: Fn(ChipTask<'a, PB>, &mut T) -> Result, ZKVMError>, + F: Fn(ChipTask<'a, PB>, &mut T) -> Result, ZKVMError>, { if tasks.is_empty() { return Ok((vec![], vec![])); @@ -253,12 +255,12 @@ impl ChipScheduler { mut tasks: Vec>, transcript: &T, execute_task: F, - ) -> Result<(Vec>, Vec), ZKVMError> + ) -> Result<(Vec>, Vec), ZKVMError> where PB: ProverBackend + 'static, PB::E: Send + 'static, T: Transcript + Clone, - F: Fn(ChipTask<'a, PB>, &mut T) -> Result, ZKVMError> + Send + Sync, + F: Fn(ChipTask<'a, PB>, &mut T) -> Result, ZKVMError> + Send + Sync, { if tasks.is_empty() { return Ok((vec![], vec![])); @@ -311,15 +313,15 @@ impl ChipScheduler { // Worker -> Scheduler: CompletionMessage (includes sampled value) let (task_tx, task_rx) = mpsc::channel::>(); let task_rx = Arc::new(Mutex::new(task_rx)); - let (done_tx, done_rx) = mpsc::channel::>(); + let (done_tx, done_rx) = mpsc::channel::>(); // 3. State tracking let mut tasks_inflight = 0usize; - let mut results: Vec> = Vec::with_capacity(total_tasks); + let mut results: Vec> = Vec::with_capacity(total_tasks); let mut samples: Vec<(usize, PB::E)> = Vec::with_capacity(total_tasks); // Helper to handle a completion message - let mut handle_completion = |msg: CompletionMessage, + let mut handle_completion = |msg: CompletionMessage<'a, PB>, mem_pool: &ceno_gpu::common::mem_pool::CudaMemPool, tasks_inflight: &mut usize, label: &str| diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 568d548b6..329cd9835 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -39,10 +39,7 @@ use super::{ utils::infer_tower_product_witness, verifier::{TowerVerify, ZKVMVerifier}, }; -use crate::{ - e2e::ShardContext, scheme::constants::NUM_FANIN, structs::PointAndEval, - tables::DynamicRangeTableCircuit, -}; +use crate::{e2e::ShardContext, tables::DynamicRangeTableCircuit}; use itertools::Itertools; use mpcs::{ PolynomialCommitmentScheme, SecurityLevel, SecurityLevel::Conjecture100bits, WhirDefault, @@ -174,7 +171,6 @@ fn test_rw_lk_expression_combination() { zkvm_fixed_traces, ) .unwrap(); - let vk = pk.get_vk_slow(); // generate mock witness let num_instances = 1 << 8; @@ -248,7 +244,7 @@ fn test_rw_lk_expression_combination() { num_instances: [num_instances, 0], has_ecc_ops: false, }; - let task = crate::scheme::scheduler::ChipTask { + let mut task = crate::scheme::scheduler::ChipTask { task_id: 0, circuit_name: name.clone(), circuit_idx: 0, @@ -264,12 +260,10 @@ fn test_rw_lk_expression_combination() { num_witin: 0, structural_rmm: None, }; - let (proof, _, _) = prover - .create_chip_proof(&task, &mut transcript) + let (_proof, _main_job) = prover + .create_chip_proof(&mut task, &mut transcript) .expect("create_proof failed"); - // verify proof - let verifier = ZKVMVerifier::new(vk.clone()); let mut v_transcript = BasicTranscript::new(b"test"); // write commitment into transcript and derive challenges from it Pcs::write_commitment(&witin_commit, &mut v_transcript).unwrap(); @@ -283,18 +277,6 @@ fn test_rw_lk_expression_combination() { { Instrumented::<<::BaseField as PoseidonField>::P>::clear_metrics(); } - let _ = verifier - .verify_chip_proof( - name.as_str(), - verifier.vk.circuit_vks.get(&name).unwrap(), - &proof, - &PublicValues::default(), - &mut v_transcript, - NUM_FANIN, - &PointAndEval::default(), - &verifier_challenges, - ) - .expect("verifier failed"); #[cfg(debug_assertions)] { println!( diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index e45cae99d..6a51ed550 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -8,7 +8,7 @@ use std::{ #[cfg(debug_assertions)] use ff_ext::{Instrumented, PoseidonField}; -use super::{PublicValues, ZKVMChipProof, ZKVMProof}; +use super::{MainConstraintProof, PublicValues, ZKVMChipProof, ZKVMProof}; use crate::{ error::ZKVMError, instructions::riscv::constants::{ @@ -18,7 +18,10 @@ use crate::{ scheme::{ constants::{NUM_FANIN, SEPTIC_EXTENSION_DEGREE}, septic_curve::{SepticExtension, SepticPoint}, - utils::{assign_group_evals, derive_ecc_bridge_claims}, + utils::{ + GkrOutputStageMask, assign_group_evals, derive_ecc_bridge_claims, + first_layer_output_group_stage_masks, + }, }, structs::{ ComposedConstrainSystem, EccQuarkProof, PointAndEval, TowerProofs, VerifyingKey, @@ -29,19 +32,31 @@ use ceno_emul::{FullTracer as Tracer, WORD_SIZE}; use gkr_iop::{ self, selector::{SelectorContext, SelectorType}, + utils::{ + eval_inner_repeated_incremental_vec, eval_outer_repeated_incremental_vec, + eval_stacked_constant_vec, eval_stacked_wellform_address_vec, eval_wellform_address_vec, + }, }; use itertools::{Itertools, chain, interleave, izip}; use mpcs::{Point, PolynomialCommitmentScheme}; use multilinear_extensions::{ Expression, StructuralWitIn, - StructuralWitInType::StackedConstantSequence, - mle::IntoMLE, + StructuralWitInType::{ + Empty, EqualDistanceDynamicSequence, EqualDistanceSequence, + InnerRepeatingIncrementalSequence, OuterRepeatingIncrementalSequence, + StackedConstantSequence, StackedIncrementalSequence, + }, + mle::{IntoMLE, MultilinearExtension}, + monomial::Term, util::ceil_log2, + utils::eval_by_expr_with_instance, virtual_poly::{VPAuxInfo, build_eq_x_r_vec_sequential, eq_eval}, + virtual_polys::VirtualPolynomials, }; -use p3::field::FieldAlgebra; +use p3::field::{FieldAlgebra, dot_product}; use sumcheck::{ - structs::{IOPProof, IOPVerifierState}, + frontload, + structs::{IOPProof, IOPVerifierState, SumCheckSubClaim}, util::get_challenge_pows, }; use transcript::{ForkableTranscript, Transcript}; @@ -49,6 +64,8 @@ use witness::next_pow2_instance_padding; pub use crate::structs::RV32imMemStateConfig; +type BatchedMainOpeningEvals = Vec<(Point, Vec, Vec)>; + pub struct ZKVMVerifier< E: ExtensionField, PCS: PolynomialCommitmentScheme, @@ -59,6 +76,147 @@ pub struct ZKVMVerifier< pub vk: ZKVMVerifyingKey, } +pub(crate) struct PendingMainConstraintVerification<'a, E: ExtensionField> { + circuit_name: &'a str, + circuit_vk: &'a VerifyingKey, + proof: &'a ZKVMChipProof, + num_var_with_rotation: usize, + out_evals: Vec>, + pi: Vec, + selector_ctxs: Vec, +} + +fn validate_batched_main_structural_evals( + pending: &PendingMainConstraintVerification<'_, E>, + layer: &gkr_iop::gkr::layer::Layer, + eval_and_dedup_points: &[(Vec, Option>)], + layer_evals: &[E], + in_point: &Point, +) -> Result<(), String> { + let structural_witin_offset = layer.n_witin + layer.n_fixed; + for (((sel_type, _), (_, out_point)), selector_ctx) in layer + .out_sel_and_eval_exprs + .iter() + .zip(eval_and_dedup_points.iter()) + .zip(pending.selector_ctxs.iter()) + { + if let Some((expected_eval, wit_id)) = + sel_type.evaluate(out_point.as_ref().unwrap(), in_point, selector_ctx) + { + let wit_id = wit_id as usize + structural_witin_offset; + if layer_evals[wit_id] != expected_eval { + return Err(format!( + "{} selector structural witin mismatch wit_id={wit_id} expected={expected_eval} got={}", + pending.circuit_name, layer_evals[wit_id] + )); + } + } + } + + for StructuralWitIn { id, witin_type } in &layer.structural_witins { + let wit_id = *id as usize + structural_witin_offset; + let expected_eval = match witin_type { + EqualDistanceSequence { + offset, + multi_factor, + descending, + .. + } => eval_wellform_address_vec( + *offset as u64, + *multi_factor as u64, + in_point, + *descending, + ), + EqualDistanceDynamicSequence { + offset_instance_id, + multi_factor, + descending, + .. + } => { + let offset = pending.pi[*offset_instance_id as usize].to_canonical_u64(); + eval_wellform_address_vec(offset, *multi_factor as u64, in_point, *descending) + } + StackedIncrementalSequence { .. } => eval_stacked_wellform_address_vec(in_point), + StackedConstantSequence { .. } => eval_stacked_constant_vec(in_point), + InnerRepeatingIncrementalSequence { k, .. } => { + eval_inner_repeated_incremental_vec(*k as u64, in_point) + } + OuterRepeatingIncrementalSequence { k, .. } => { + eval_outer_repeated_incremental_vec(*k as u64, in_point) + } + Empty => continue, + }; + if expected_eval != layer_evals[wit_id] { + return Err(format!( + "{} structural witin mismatch", + pending.circuit_name + )); + } + } + + Ok(()) +} + +fn eval_batched_main_frontload_terms( + layer_evals: &[E], + pi: &[E], + challenges: &[E], + global_in_point: &[E], + num_var_with_rotation: usize, + terms: &[Term, Expression>], +) -> E { + let evaluated_terms = terms + .iter() + .map(|term| { + let scalar = eval_by_expr_with_instance(&[], &[], &[], pi, challenges, &term.scalar); + let product_wit_ids = term + .product + .iter() + .map(|expr| { + let Expression::WitIn(wit_id) = expr else { + panic!("main monomial product must be converted to WitIn") + }; + *wit_id as usize + }) + .collect_vec(); + (scalar, product_wit_ids) + }) + .collect_vec(); + + let constant_mles = evaluated_terms + .iter() + .flat_map(|(_, product_wit_ids)| { + product_wit_ids.iter().map(|wit_id| { + MultilinearExtension::from_evaluations_ext_vec(0, vec![layer_evals[*wit_id]]) + }) + }) + .collect_vec(); + + let mut raw_mle_evals = Vec::with_capacity(constant_mles.len()); + let mut mle_index = 0usize; + let monomial_terms = evaluated_terms + .into_iter() + .map(|(scalar, product_wit_ids)| { + let product = product_wit_ids + .into_iter() + .map(|wit_id| { + let mle = &constant_mles[mle_index]; + mle_index += 1; + raw_mle_evals.push(layer_evals[wit_id]); + Either::Left(mle) + }) + .collect_vec(); + Term { scalar, product } + }) + .collect_vec(); + + let tail_point = &global_in_point[num_var_with_rotation..]; + let (mut polys, _) = + VirtualPolynomials::new_from_monimials(1, tail_point.len(), monomial_terms) + .get_batched_polys(); + frontload::evaluate(&polys.remove(0), tail_point, &raw_mle_evals) +} + fn bind_active_tower_eval_round( transcript: &mut impl Transcript, tower_proofs: &TowerProofs, @@ -160,39 +318,6 @@ impl> Ok((next_heap_addr_end, next_hint_addr_end)) } - #[allow(clippy::type_complexity)] - fn split_input_opening_evals( - circuit_vk: &VerifyingKey, - proof: &ZKVMChipProof, - ) -> Result<(Vec, Vec), ZKVMError> { - let cs = circuit_vk.get_cs(); - let Some(gkr_proof) = proof.gkr_iop_proof.as_ref() else { - return Err(ZKVMError::InvalidProof("missing gkr proof".into())); - }; - let Some(last_layer) = gkr_proof.0.last() else { - return Err(ZKVMError::InvalidProof("empty gkr proof layers".into())); - }; - - let evals = &last_layer.main.evals; - let wit_len = cs.num_witin(); - let fixed_len = cs.num_fixed(); - let min_len = wit_len + fixed_len; - if evals.len() < min_len { - return Err(ZKVMError::InvalidProof( - format!( - "insufficient main evals: {} < required {}", - evals.len(), - min_len - ) - .into(), - )); - } - - let wits_in_evals = evals[..wit_len].to_vec(); - let fixed_in_evals = evals[wit_len..(wit_len + fixed_len)].to_vec(); - Ok((wits_in_evals, fixed_in_evals)) - } - /// Verify a full zkVM trace from program entry to halt. /// /// This is the production verifier API. It treats a single proof as a @@ -468,6 +593,7 @@ impl> } // fork transcript to support chip concurrently proved + let mut pending_main_constraints = Vec::with_capacity(num_proofs); let mut forked_transcripts = transcript.fork(num_proofs); for ((index, proof), transcript) in vm_proof .chip_proofs @@ -585,29 +711,17 @@ impl> // accumulate logup_sum logup_sum += chip_logup_sum; - let (input_opening_point, chip_shard_ec_sum, wits_in_evals, fixed_in_evals) = self - .verify_chip_proof( - circuit_name, - circuit_vk, - proof, - &vm_proof.public_values, - transcript, - NUM_FANIN, - &point_eval, - &challenges, - )?; - if circuit_vk.get_cs().num_witin() > 0 { - witin_openings.push(( - input_opening_point.len(), - (input_opening_point.clone(), wits_in_evals), - )); - } - if circuit_vk.get_cs().num_fixed() > 0 { - fixed_openings.push(( - input_opening_point.len(), - (input_opening_point.clone(), fixed_in_evals), - )); - } + let (pending_main_constraint, chip_shard_ec_sum) = self.verify_chip_proof_pre_main( + circuit_name, + circuit_vk, + proof, + &vm_proof.public_values, + transcript, + NUM_FANIN, + &point_eval, + &challenges, + )?; + pending_main_constraints.push(pending_main_constraint); prod_w *= proof.w_out_evals.iter().flatten().copied().product::(); prod_r *= proof.r_out_evals.iter().flatten().copied().product::(); tracing::debug!( @@ -637,6 +751,28 @@ impl> transcript.append_field_element_ext(&sample); } + for (input_opening_point, wits_in_evals, fixed_in_evals) in self + .verify_batched_main_constraints( + pending_main_constraints, + &vm_proof.main_constraint_proof, + &mut transcript, + &challenges, + )? + { + if !wits_in_evals.is_empty() { + witin_openings.push(( + input_opening_point.len(), + (input_opening_point.clone(), wits_in_evals), + )); + } + if !fixed_in_evals.is_empty() { + fixed_openings.push(( + input_opening_point.len(), + (input_opening_point.clone(), fixed_in_evals), + )); + } + } + // verify mpcs let mut rounds = vec![(vm_proof.witin_commit.clone(), witin_openings)]; @@ -676,17 +812,23 @@ impl> /// verify proof and return input opening point #[allow(clippy::too_many_arguments, clippy::type_complexity)] - pub fn verify_chip_proof( + pub(crate) fn verify_chip_proof_pre_main<'a>( &self, - _name: &str, - circuit_vk: &VerifyingKey, - proof: &ZKVMChipProof, + _name: &'a str, + circuit_vk: &'a VerifyingKey, + proof: &'a ZKVMChipProof, public_values: &PublicValues, transcript: &mut impl Transcript, num_product_fanin: usize, _out_evals: &PointAndEval, challenges: &[E; 2], // derive challenge from PCS - ) -> Result<(Point, Option>, Vec, Vec), ZKVMError> { + ) -> Result< + ( + PendingMainConstraintVerification<'a, E>, + Option>, + ), + ZKVMError, + > { let composed_cs = circuit_vk.get_cs(); let ComposedConstrainSystem { zkvm_v1_css: cs, @@ -841,11 +983,13 @@ impl> let first_layer = gkr_circuit.layers.first().ok_or_else(|| { ZKVMError::InvalidProof(format!("{_name} empty gkr circuit layers").into()) })?; + let group_stage_masks = first_layer_output_group_stage_masks(composed_cs, gkr_circuit); let selector_ctxs = first_layer .out_sel_and_eval_exprs .iter() - .map(|(selector, _)| { - if cs.ec_final_sum.is_empty() { + .zip_eq(group_stage_masks.iter()) + .map(|((selector, _), stage_mask)| { + if !stage_mask.contains(GkrOutputStageMask::TOWER) || cs.ec_final_sum.is_empty() { SelectorContext::new(0, num_instances, num_var_with_rotation) } else if cs.r_selector.as_ref() == Some(selector) { SelectorContext::new(0, proof.num_instances[0], num_var_with_rotation) @@ -921,76 +1065,261 @@ impl> ); } - if let Some(ecc_proof) = proof.ecc_proof.as_ref() { - let Some( - [ - x_group_idx, - y_group_idx, - slope_group_idx, - x3_group_idx, - y3_group_idx, - ], - ) = first_layer.ecc_bridge_group_indices() - else { + let pi = cs + .instance + .iter() + .map(|instance| E::from(public_values.query_by_index::(instance.0))) + .collect_vec(); + Ok(( + PendingMainConstraintVerification { + circuit_name: _name, + circuit_vk, + proof, + num_var_with_rotation, + out_evals, + pi, + selector_ctxs, + }, + shard_ec_sum, + )) + } + + fn verify_batched_main_constraints( + &self, + pending_main_constraints: Vec>, + main_constraint_proof: &MainConstraintProof, + transcript: &mut impl Transcript, + challenges: &[E; 2], + ) -> Result, ZKVMError> { + if pending_main_constraints.is_empty() { + if !main_constraint_proof.proof.proof.proofs.is_empty() + || !main_constraint_proof.proof.evals.is_empty() + { return Err(ZKVMError::InvalidProof( - "ecc bridge claims expected but selectors are missing".into(), + "empty main constraints with non-empty proof".into(), )); - }; + } + return Ok(vec![]); + } - let sample_r = transcript.sample_and_append_vec(b"ecc_gkr_bridge_r", 1)[0]; - let claims = derive_ecc_bridge_claims(ecc_proof, sample_r, num_var_with_rotation)?; + struct PendingLayer<'a, E: ExtensionField> { + pending: PendingMainConstraintVerification<'a, E>, + layer: &'a gkr_iop::gkr::layer::Layer, + eval_and_dedup_points: Vec<(Vec, Option>)>, + eval_start: usize, + eval_len: usize, + alpha_start: usize, + } - assign_group_evals( - &mut out_evals, - &first_layer.out_sel_and_eval_exprs[x_group_idx].1, - &claims.x_evals, - &claims.xy_point, - ); - assign_group_evals( - &mut out_evals, - &first_layer.out_sel_and_eval_exprs[y_group_idx].1, - &claims.y_evals, - &claims.xy_point, - ); - assign_group_evals( - &mut out_evals, - &first_layer.out_sel_and_eval_exprs[slope_group_idx].1, - &claims.s_evals, - &claims.s_point, - ); - assign_group_evals( - &mut out_evals, - &first_layer.out_sel_and_eval_exprs[x3_group_idx].1, - &claims.x3_evals, - &claims.x3y3_point, - ); - assign_group_evals( - &mut out_evals, - &first_layer.out_sel_and_eval_exprs[y3_group_idx].1, - &claims.y3_evals, - &claims.x3y3_point, - ); + let mut layers = Vec::with_capacity(pending_main_constraints.len()); + let mut total_exprs = 0usize; + let mut total_evals = 0usize; + let mut max_num_variables = 0usize; + let mut max_degree = 0usize; + + for pending in pending_main_constraints { + let gkr_circuit = pending + .circuit_vk + .get_cs() + .gkr_circuit + .as_ref() + .ok_or_else(|| { + ZKVMError::InvalidProof( + format!("{} missing gkr circuit in vk", pending.circuit_name).into(), + ) + })?; + let layer = gkr_circuit.layers.first().ok_or_else(|| { + ZKVMError::InvalidProof( + format!("{} empty gkr circuit layers", pending.circuit_name).into(), + ) + })?; + max_num_variables = max_num_variables.max(pending.num_var_with_rotation); + max_degree = max_degree.max(layer.max_expr_degree + 1); + + let mut out_evals = pending.out_evals.clone(); + if let Some(ecc_proof) = pending.proof.ecc_proof.as_ref() { + let Some( + [ + x_group_idx, + y_group_idx, + slope_group_idx, + x3_group_idx, + y3_group_idx, + ], + ) = layer.ecc_bridge_group_indices() + else { + return Err(ZKVMError::InvalidProof( + "ecc bridge claims expected but selectors are missing".into(), + )); + }; + + let sample_r = transcript.sample_and_append_vec(b"ecc_gkr_bridge_r", 1)[0]; + let claims = + derive_ecc_bridge_claims(ecc_proof, sample_r, pending.num_var_with_rotation)?; + + assign_group_evals( + &mut out_evals, + &layer.out_sel_and_eval_exprs[x_group_idx].1, + &claims.x_evals, + &claims.xy_point, + ); + assign_group_evals( + &mut out_evals, + &layer.out_sel_and_eval_exprs[y_group_idx].1, + &claims.y_evals, + &claims.xy_point, + ); + assign_group_evals( + &mut out_evals, + &layer.out_sel_and_eval_exprs[slope_group_idx].1, + &claims.s_evals, + &claims.s_point, + ); + assign_group_evals( + &mut out_evals, + &layer.out_sel_and_eval_exprs[x3_group_idx].1, + &claims.x3_evals, + &claims.x3y3_point, + ); + assign_group_evals( + &mut out_evals, + &layer.out_sel_and_eval_exprs[y3_group_idx].1, + &claims.y3_evals, + &claims.x3y3_point, + ); + } + + out_evals.resize(gkr_circuit.n_evaluations, PointAndEval::default()); + let eval_and_dedup_points = layer + .out_sel_and_eval_exprs + .iter() + .map(|(_, out_eval_exprs)| { + let evals = out_eval_exprs + .iter() + .map(|out_eval| out_eval.evaluate(&out_evals, challenges).eval) + .collect_vec(); + let point = out_eval_exprs + .first() + .map(|out_eval| out_eval.evaluate(&out_evals, challenges).point); + (evals, point) + }) + .collect_vec(); + + let eval_len = layer.n_witin + layer.n_fixed + layer.n_structural_witin; + layers.push(PendingLayer { + pending, + layer, + eval_and_dedup_points, + eval_start: total_evals, + eval_len, + alpha_start: total_exprs, + }); + total_evals += eval_len; + total_exprs += layer.exprs.len(); } - let pi = cs - .instance + let main_evals = &main_constraint_proof.proof.evals; + if main_evals.len() != total_evals { + return Err(ZKVMError::InvalidProof( + format!( + "main constraint eval length mismatch: {} != {}", + main_evals.len(), + total_evals + ) + .into(), + )); + } + + let alpha_pows = get_challenge_pows(total_exprs, transcript); + let sigma = layers .iter() - .map(|instance| E::from(public_values.query_by_index::(instance.0))) - .collect_vec(); - let (wits_in_evals, fixed_in_evals) = Self::split_input_opening_evals(circuit_vk, proof)?; - let gkr_iop_proof = proof.gkr_iop_proof.clone().ok_or_else(|| { - ZKVMError::InvalidProof(format!("{_name} missing gkr iop proof").into()) - })?; - let (_, rt) = gkr_circuit.verify( - num_var_with_rotation, - gkr_iop_proof, - &out_evals, - &pi, - challenges, + .map(|pending_layer| { + let alpha = &alpha_pows[pending_layer.alpha_start + ..pending_layer.alpha_start + pending_layer.layer.exprs.len()]; + dot_product( + alpha.iter().copied(), + pending_layer + .eval_and_dedup_points + .iter() + .flat_map(|(sigmas, _)| sigmas) + .copied(), + ) + }) + .sum::(); + + let SumCheckSubClaim { + point: global_in_point, + expected_evaluation, + } = IOPVerifierState::verify( + sigma, + &main_constraint_proof.proof.proof, + &VPAuxInfo { + max_degree, + max_num_variables, + phantom: PhantomData, + }, transcript, - &selector_ctxs, - )?; - Ok((rt, shard_ec_sum, wits_in_evals, fixed_in_evals)) + ); + let global_in_point = global_in_point + .into_iter() + .map(|challenge| challenge.elements) + .collect_vec(); + transcript.append_field_element_exts(main_evals); + + let mut got_claim = E::ZERO; + let mut results = Vec::with_capacity(layers.len()); + for pending_layer in &layers { + let in_point = global_in_point[..pending_layer.pending.num_var_with_rotation].to_vec(); + let layer_evals = &main_evals + [pending_layer.eval_start..pending_layer.eval_start + pending_layer.eval_len]; + + validate_batched_main_structural_evals( + &pending_layer.pending, + pending_layer.layer, + &pending_layer.eval_and_dedup_points, + layer_evals, + &in_point, + ) + .map_err(|err| ZKVMError::InvalidProof(err.into()))?; + + let main_sumcheck_challenges = chain!( + challenges.iter().copied(), + alpha_pows[pending_layer.alpha_start + ..pending_layer.alpha_start + pending_layer.layer.exprs.len()] + .iter() + .copied() + ) + .collect_vec(); + got_claim += eval_batched_main_frontload_terms( + layer_evals, + &pending_layer.pending.pi, + &main_sumcheck_challenges, + &global_in_point, + pending_layer.pending.num_var_with_rotation, + pending_layer + .layer + .main_sumcheck_expression_monomial_terms + .as_ref() + .unwrap(), + ); + + results.push(( + in_point, + layer_evals[..pending_layer.layer.n_witin].to_vec(), + layer_evals[pending_layer.layer.n_witin + ..pending_layer.layer.n_witin + pending_layer.layer.n_fixed] + .to_vec(), + )); + } + + if got_claim != expected_evaluation { + return Err(ZKVMError::InvalidProof( + format!("main constraint claim mismatch: {expected_evaluation} != {got_claim}") + .into(), + )); + } + + Ok(results) } } diff --git a/ceno_zkvm/src/tables/shard_ram.rs b/ceno_zkvm/src/tables/shard_ram.rs index 1521f1f83..f28a4a29b 100644 --- a/ceno_zkvm/src/tables/shard_ram.rs +++ b/ceno_zkvm/src/tables/shard_ram.rs @@ -739,9 +739,9 @@ mod tests { circuit_builder::{CircuitBuilder, ConstraintSystem}, scheme::{ PublicValues, constants::SEPTIC_EXTENSION_DEGREE, create_backend, create_prover, - hal::ProofInput, prover::ZKVMProver, septic_curve::SepticPoint, verifier::ZKVMVerifier, + hal::ProofInput, prover::ZKVMProver, septic_curve::SepticPoint, }, - structs::{ComposedConstrainSystem, PointAndEval, ProgramParams, RAMType, ZKVMProvingKey}, + structs::{ComposedConstrainSystem, ProgramParams, RAMType, ZKVMProvingKey}, tables::{ShardRamCircuit, ShardRamInput, ShardRamRecord, TableCircuit}, }; #[cfg(feature = "gpu")] @@ -873,7 +873,6 @@ mod tests { let pd = create_prover(backend); let zkvm_pk = ZKVMProvingKey::new(pp, vp); - let zkvm_vk = zkvm_pk.get_vk_slow(); let zkvm_prover = ZKVMProver::new(zkvm_pk.into(), pd); let mut transcript = BasicTranscript::new(b"global chip test"); @@ -919,7 +918,7 @@ mod tests { }; let mut rng = thread_rng(); let challenges = [E::random(&mut rng), E::random(&mut rng)]; - let task = crate::scheme::scheduler::ChipTask { + let mut task = crate::scheme::scheduler::ChipTask { task_id: 0, circuit_name: ShardRamCircuit::::name(), circuit_idx: 0, @@ -935,24 +934,8 @@ mod tests { num_witin: 0, structural_rmm: None, }; - let (proof, _, point) = zkvm_prover - .create_chip_proof(&task, &mut transcript) + let (_proof, _main_job) = zkvm_prover + .create_chip_proof(&mut task, &mut transcript) .unwrap(); - - let mut transcript = BasicTranscript::new(b"global chip test"); - let verifier = ZKVMVerifier::new(zkvm_vk); - let (vrf_point, _, _, _) = verifier - .verify_chip_proof( - "global", - &pk.vk, - &proof, - &public_value, - &mut transcript, - 2, - &PointAndEval::default(), - &challenges, - ) - .expect("verify global chip proof"); - assert_eq!(vrf_point, point); } } diff --git a/gkr_iop/src/gkr/layer/zerocheck_layer.rs b/gkr_iop/src/gkr/layer/zerocheck_layer.rs index be1727fe2..7d0f66592 100644 --- a/gkr_iop/src/gkr/layer/zerocheck_layer.rs +++ b/gkr_iop/src/gkr/layer/zerocheck_layer.rs @@ -660,7 +660,7 @@ fn log_common_term_plan_stats( let naive_mul_count: usize = term_factor_counts.iter().sum(); let coverage_percentage = (shared_terms as f64 / total_terms.max(1) as f64) * 100.0; let factored_percentage = (factored_terms as f64 / total_terms.max(1) as f64) * 100.0; - tracing::info!( + tracing::debug!( target: "gkr::layer", "[CommonFactoredTermPlan] gkr::layer {} groups={} shared_terms={}/{} ({coverage_percentage:.2}%) factored_terms={}/{} ({factored_percentage:.2}%) common_wit_range=[{}, {}] naive_mul={} factored_mul={}", layer_name,