Skip to content
Merged
7 changes: 7 additions & 0 deletions .github/workflows/integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,13 @@ jobs:
RUSTFLAGS: "-C opt-level=3"
run: cargo run --release --package ceno_zkvm --bin e2e -- --platform=ceno examples/target/riscv32im-ceno-zkvm-elf/release/examples/bn254_curve_syscalls

- name: Run fibonacci (release) in 3 shards with CENO_CROSS_SHARD_LIMIT
env:
RUST_LOG: debug
RUSTFLAGS: "-C opt-level=3"
CENO_CROSS_SHARD_LIMIT: 32
run: cargo run --release --package ceno_zkvm --features sanity-check --bin e2e -- --platform=ceno --min-cycle-per-shard=10 --max-cycle-per-shard=20000 --hints=10 --public-io=4191 examples/target/riscv32im-ceno-zkvm-elf/release/examples/fibonacci

- name: Install cargo make
run: |
cargo make --version || cargo install cargo-make
Expand Down
45 changes: 34 additions & 11 deletions ceno_zkvm/src/e2e.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::{
hal::ProverDevice,
mock_prover::{LkMultiplicityKey, MockProver},
prover::ZKVMProver,
septic_curve::SepticPoint,
verifier::ZKVMVerifier,
},
state::GlobalState,
Expand Down Expand Up @@ -44,6 +45,7 @@ use witness::next_pow2_instance_padding;

pub const DEFAULT_MIN_CYCLE_PER_SHARDS: Cycle = 1 << 24;
pub const DEFAULT_MAX_CYCLE_PER_SHARDS: Cycle = 1 << 27;
pub const DEFAULT_CROSS_SHARD_ACCESS_LIMIT: usize = 1 << 20;

/// The polynomial commitment scheme kind
#[derive(
Expand Down Expand Up @@ -175,11 +177,16 @@ pub struct ShardContext<'a> {
Either<Vec<BTreeMap<WordAddr, RAMRecord>>, &'a mut BTreeMap<WordAddr, RAMRecord>>,
pub cur_shard_cycle_range: std::ops::Range<usize>,
pub expected_inst_per_shard: usize,
pub max_num_cross_shard_accesses: usize,
}

impl<'a> Default for ShardContext<'a> {
fn default() -> Self {
let max_threads = max_usable_threads();
let max_num_cross_shard_accesses = std::env::var("CENO_CROSS_SHARD_LIMIT")
.map(|v| v.parse().unwrap_or(DEFAULT_CROSS_SHARD_ACCESS_LIMIT))
.unwrap_or(DEFAULT_CROSS_SHARD_ACCESS_LIMIT);

Self {
shard_id: 0,
num_shards: 1,
Expand All @@ -202,6 +209,7 @@ impl<'a> Default for ShardContext<'a> {
),
cur_shard_cycle_range: Tracer::SUBCYCLES_PER_INSN as usize..usize::MAX,
expected_inst_per_shard: usize::MAX,
max_num_cross_shard_accesses,
}
}
}
Expand Down Expand Up @@ -231,6 +239,10 @@ impl<'a> ShardContext<'a> {
let subcycle_per_insn = Tracer::SUBCYCLES_PER_INSN as usize;
let max_threads = max_usable_threads();

let max_num_cross_shard_accesses = std::env::var("CENO_CROSS_SHARD_LIMIT")
.map(|v| v.parse().unwrap_or(DEFAULT_CROSS_SHARD_ACCESS_LIMIT))
.unwrap_or(DEFAULT_CROSS_SHARD_ACCESS_LIMIT);

// strategies
// 0. set cur_num_shards = num_provers
// 1. split instructions evenly by cur_num_shards
Expand Down Expand Up @@ -323,6 +335,7 @@ impl<'a> ShardContext<'a> {
),
cur_shard_cycle_range,
expected_inst_per_shard,
max_num_cross_shard_accesses,
}
})
.collect_vec()
Expand Down Expand Up @@ -355,6 +368,7 @@ impl<'a> ShardContext<'a> {
write_records_tbs: Either::Right(write),
cur_shard_cycle_range: self.cur_shard_cycle_range.clone(),
expected_inst_per_shard: self.expected_inst_per_shard,
max_num_cross_shard_accesses: self.max_num_cross_shard_accesses,
},
)
.collect_vec(),
Expand Down Expand Up @@ -1125,17 +1139,26 @@ pub fn generate_witness<'a, E: ExtensionField>(
pi.end_pc = current_shard_end_pc;
pi.end_cycle = current_shard_end_cycle;
// set shard ram bus expected output to pi
let shard_ram_witness = zkvm_witness.get_table_witness(&ShardRamCircuit::<E>::name());
if let Some(shard_ram_witness) = shard_ram_witness
&& shard_ram_witness[0].num_instances() > 0
{
for (f, v) in ShardRamCircuit::<E>::extract_ec_sum(
&system_config.mmu_config.ram_bus_circuit,
&shard_ram_witness[0],
)
.into_iter()
.zip_eq(pi.shard_rw_sum.as_mut_slice())
{
let shard_ram_witnesses = zkvm_witness.get_witness(&ShardRamCircuit::<E>::name());

if let Some(shard_ram_witnesses) = shard_ram_witnesses {
let shard_ram_ec_sum: SepticPoint<E::BaseField> = shard_ram_witnesses
.iter()
.filter(|shard_ram_witness| shard_ram_witness.num_instances[0] > 0)
.map(|shard_ram_witness| {
ShardRamCircuit::<E>::extract_ec_sum(
&system_config.mmu_config.ram_bus_circuit,
&shard_ram_witness.witness_rmms[0],
)
})
.sum();

let xy = shard_ram_ec_sum
.x
.0
.iter()
.chain(shard_ram_ec_sum.y.0.iter());
for (f, v) in xy.zip_eq(pi.shard_rw_sum.as_mut_slice()) {
*v = f.to_canonical_u64() as u32;
}
}
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ impl<E: ExtensionField> MmuConfig<'_, E> {
&self.local_final_circuit,
&(shard_ctx, all_records.as_slice()),
)?;
witness.assign_global_chip_circuit(
witness.assign_shared_circuit(
cs,
&(shard_ctx, all_records.as_slice()),
&self.ram_bus_circuit,
Expand Down
3 changes: 3 additions & 0 deletions ceno_zkvm/src/keygen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ impl<E: ExtensionField> ZKVMConstraintSystem<E> {
fixed_traces.insert(circuit_index, fixed_trace_rmm);
}

vm_pk
.circuit_name_to_index
.insert(c_name.clone(), circuit_index);
let circuit_pk = cs.key_gen();
assert!(vm_pk.circuit_pks.insert(c_name, circuit_pk).is_none());
}
Expand Down
20 changes: 17 additions & 3 deletions ceno_zkvm/src/scheme.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use serde::{Deserialize, Serialize, de::DeserializeOwned};
use std::{
collections::{BTreeMap, HashMap},
fmt::{self, Debug},
iter,
ops::Div,
rc::Rc,
};
Expand Down Expand Up @@ -156,7 +157,8 @@ pub struct ZKVMProof<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> {
pub raw_pi: Vec<Vec<E::BaseField>>,
// the evaluation of raw_pi.
pub pi_evals: Vec<E>,
pub chip_proofs: BTreeMap<usize, ZKVMChipProof<E>>,
// each circuit may have multiple proof instances
pub chip_proofs: BTreeMap<usize, Vec<ZKVMChipProof<E>>>,
pub witin_commit: <PCS as PolynomialCommitmentScheme<E>>::Commitment,
pub opening_proof: PCS::Proof,
}
Expand All @@ -165,7 +167,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProof<E, PCS> {
pub fn new(
raw_pi: Vec<Vec<E::BaseField>>,
pi_evals: Vec<E>,
chip_proofs: BTreeMap<usize, ZKVMChipProof<E>>,
chip_proofs: BTreeMap<usize, Vec<ZKVMChipProof<E>>>,
witin_commit: <PCS as PolynomialCommitmentScheme<E>>::Commitment,
opening_proof: PCS::Proof,
) -> Self {
Expand Down Expand Up @@ -211,7 +213,13 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProof<E, PCS> {
let halt_instance_count = self
.chip_proofs
.get(&halt_circuit_index)
.map_or(0, |proof| proof.num_instances.iter().sum());
.map_or(0, |proofs| {
proofs
.iter()
.flat_map(|proof| &proof.num_instances)
.copied()
.sum()
});
if halt_instance_count > 0 {
assert_eq!(
halt_instance_count, 1,
Expand Down Expand Up @@ -240,6 +248,9 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E> + Serialize> fmt::Dis
let tower_proof = self
.chip_proofs
.iter()
.flat_map(|(circuit_index, proofs)| {
iter::repeat_n(circuit_index, proofs.len()).zip(proofs)
})
.map(|(circuit_index, proof)| {
let size = bincode::serialized_size(&proof.tower_proof);
size.inspect(|size| {
Expand All @@ -254,6 +265,9 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E> + Serialize> fmt::Dis
let main_sumcheck = self
.chip_proofs
.iter()
.flat_map(|(circuit_index, proofs)| {
iter::repeat_n(circuit_index, proofs.len()).zip(proofs)
})
.map(|(circuit_index, proof)| {
let size = bincode::serialized_size(&proof.main_sumcheck_proofs);
size.inspect(|size| {
Expand Down
41 changes: 12 additions & 29 deletions ceno_zkvm/src/scheme/mock_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ use p3::field::{Field, FieldAlgebra};
use rand::thread_rng;
use std::{
cmp::max,
collections::{BTreeMap, BTreeSet, HashMap, HashSet},
collections::{BTreeSet, HashMap, HashSet},
fmt::Debug,
fs::File,
hash::Hash,
Expand Down Expand Up @@ -1004,21 +1004,13 @@ Hints:
let mut fixed_mles = HashMap::new();
let mut num_instances = HashMap::new();

let circuit_index_fixed_num_instances: BTreeMap<String, usize> = fixed_trace
.circuit_fixed_traces
.iter()
.map(|(circuit_name, rmm)| {
(
circuit_name.clone(),
rmm.as_ref().map(|rmm| rmm.num_instances()).unwrap_or(0),
)
})
.collect();
let mut lkm_tables = LkMultiplicityRaw::<E>::default();
let mut lkm_opcodes = LkMultiplicityRaw::<E>::default();

// Process all circuits.
for (circuit_name, composed_cs) in &cs.circuit_css {
for (circuit_name, chip_inputs) in &witnesses.witnesses {
let composed_cs = cs.circuit_css.get(circuit_name).unwrap();
// for (circuit_name, composed_cs) in &cs.circuit_css {
let ComposedConstrainSystem {
zkvm_v1_css: cs, ..
} = &composed_cs;
Expand All @@ -1037,30 +1029,21 @@ Hints:
continue;
}

let [witness, structural_witness] = witnesses
.get_opcode_witness(circuit_name)
.or_else(|| witnesses.get_table_witness(circuit_name))
.unwrap_or_else(|| panic!("witness for {} should not be None", circuit_name));
let num_rows = if witness.num_instances() > 0 {
witness.num_instances()
} else if structural_witness.num_instances() > 0 {
structural_witness.num_instances()
} else if composed_cs.is_static_circuit() {
circuit_index_fixed_num_instances
.get(circuit_name)
.copied()
.unwrap_or(0)
} else {
0
};
assert!(chip_inputs.len() <= 1, "TODO support > 1 chip_inputs");
let chip_input = chip_inputs.first().filter(|ci| ci.num_instances() > 0);

if num_rows == 0 {
if chip_input.is_none() {
wit_mles.insert(circuit_name.clone(), vec![]);
structural_wit_mles.insert(circuit_name.clone(), vec![]);
fixed_mles.insert(circuit_name.clone(), vec![]);
num_instances.insert(circuit_name.clone(), 0);
continue;
}

let chip_input = chip_input.unwrap();
let num_rows = chip_input.num_instances();

let [witness, structural_witness] = &chip_input.witness_rmms;
let mut witness = witness
.to_mles()
.into_iter()
Expand Down
Loading