Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions ceno_emul/src/syscalls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pub use ceno_syscall::{
SECP256K1_ADD, SECP256K1_DECOMPRESS, SECP256K1_DOUBLE, SECP256K1_SCALAR_INVERT, SECP256R1_ADD,
SECP256R1_DECOMPRESS, SECP256R1_DOUBLE, SHA_EXTEND, UINT256_MUL,
};
pub use sha256::ShaExtendState;

pub trait SyscallSpec {
const NAME: &'static str;
Expand Down
123 changes: 92 additions & 31 deletions ceno_emul/src/syscalls/sha256.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use crate::{Change, EmuContext, Platform, Tracer, VMState, Word, WriteOp, utils::MemoryView};
use crate::{
ByteAddr, Change, Platform, Tracer, VMState, Word, WriteOp, rv32im::EmuContext,
utils::MemoryView,
};

use super::{SyscallEffects, SyscallSpec, SyscallWitness};

Expand All @@ -10,7 +13,7 @@ impl SyscallSpec for Sha256ExtendSpec {
const NAME: &'static str = "SHA256_EXTEND";

const REG_OPS_COUNT: usize = 1;
const MEM_OPS_COUNT: usize = SHA_EXTEND_WORDS;
const MEM_OPS_COUNT: usize = SHA_EXTEND_ROUND_MEM_OPS;
const CODE: u32 = ceno_syscall::SHA_EXTEND;
}

Expand All @@ -29,39 +32,97 @@ impl From<ShaExtendWords> for [Word; SHA_EXTEND_WORDS] {
}
}

/// Based on: https://github.com/succinctlabs/sp1/blob/2aed8fea16a67a5b2983ffc471b2942c2f2512c8/crates/core/machine/src/syscall/precompiles/sha256/extend/mod.rs#L22
pub fn sha_extend(w: &mut [u32]) {
for i in 16..64 {
let s0 = w[i - 15].rotate_right(7) ^ w[i - 15].rotate_right(18) ^ (w[i - 15] >> 3);
let s1 = w[i - 2].rotate_right(17) ^ w[i - 2].rotate_right(19) ^ (w[i - 2] >> 10);
w[i] = w[i - 16]
pub const SHA_EXTEND_ROUND_MEM_OPS: usize = 5;

pub struct ShaExtendState {
state_ptr: Word,
words: [Word; SHA_EXTEND_WORDS],
round: usize,
}

impl ShaExtendState {
pub fn new<T: Tracer>(vm: &VMState<T>) -> Self {
let state_ptr = vm.peek_register(Platform::reg_arg0());
let state_view = MemoryView::<_, SHA_EXTEND_WORDS>::new(vm, state_ptr);
let words = state_view.words();
Self {
state_ptr,
words,
round: 16,
}
}

pub fn is_done(&self) -> bool {
self.round >= SHA_EXTEND_WORDS
}

pub fn next_round_effects(&mut self) -> Option<SyscallEffects> {
if self.is_done() {
return None;
}

let i = self.round;
let w_i_minus_2 = self.words[i - 2];
let w_i_minus_7 = self.words[i - 7];
let w_i_minus_15 = self.words[i - 15];
let w_i_minus_16 = self.words[i - 16];

let s0 = w_i_minus_15.rotate_right(7) ^ w_i_minus_15.rotate_right(18) ^ (w_i_minus_15 >> 3);
let s1 = w_i_minus_2.rotate_right(17) ^ w_i_minus_2.rotate_right(19) ^ (w_i_minus_2 >> 10);
let new_word = w_i_minus_16
.wrapping_add(s0)
.wrapping_add(w[i - 7])
.wrapping_add(w_i_minus_7)
.wrapping_add(s1);
let old_word = self.words[i];

self.words[i] = new_word;
self.round += 1;

let base = ByteAddr::from(self.state_ptr).waddr();
let mem_ops = vec![
WriteOp {
addr: base + (i - 2) as u32,
value: Change::new(w_i_minus_2, w_i_minus_2),
previous_cycle: 0,
},
WriteOp {
addr: base + (i - 7) as u32,
value: Change::new(w_i_minus_7, w_i_minus_7),
previous_cycle: 0,
},
WriteOp {
addr: base + (i - 15) as u32,
value: Change::new(w_i_minus_15, w_i_minus_15),
previous_cycle: 0,
},
WriteOp {
addr: base + (i - 16) as u32,
value: Change::new(w_i_minus_16, w_i_minus_16),
previous_cycle: 0,
},
WriteOp {
addr: base + i as u32,
value: Change::new(old_word, new_word),
previous_cycle: 0,
},
];

let reg_ops = vec![WriteOp::new_register_op(
Platform::reg_arg0(),
Change::new(self.state_ptr, self.state_ptr),
0,
)];

Some(SyscallEffects {
witness: SyscallWitness::new(mem_ops, reg_ops),
next_pc: None,
})
}
}

pub fn extend<T: Tracer>(vm: &VMState<T>) -> SyscallEffects {
let state_ptr = vm.peek_register(Platform::reg_arg0());

// Read the argument `state_ptr`.
let reg_ops = vec![WriteOp::new_register_op(
Platform::reg_arg0(),
Change::new(state_ptr, state_ptr),
0, // Cycle set later in finalize().
)];

let mut state_view = MemoryView::<_, SHA_EXTEND_WORDS>::new(vm, state_ptr);
let mut sha_extend_words = ShaExtendWords::from(state_view.words());
sha_extend(&mut sha_extend_words.0);
let output_words: [Word; SHA_EXTEND_WORDS] = sha_extend_words.into();

state_view.write(output_words);
let mem_ops = state_view.mem_ops().to_vec();

assert_eq!(mem_ops.len(), SHA_EXTEND_WORDS);
SyscallEffects {
witness: SyscallWitness::new(mem_ops, reg_ops),
next_pc: None,
}
let mut state = ShaExtendState::new(vm);
state
.next_round_effects()
.expect("sha_extend requires at least one round")
}
73 changes: 70 additions & 3 deletions ceno_emul/src/vm_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ use crate::{
PC_STEP_SIZE, Program, WORD_SIZE,
addr::{ByteAddr, RegIdx, Word, WordAddr},
dense_addr_space::DenseAddrSpace,
encode_rv32,
platform::Platform,
rv32im::{Instruction, TrapCause},
syscalls::{SyscallEffects, handle_syscall},
rv32im::{InsnKind, Instruction, TrapCause},
syscalls::{SHA_EXTEND, ShaExtendState, SyscallEffects, handle_syscall},
tracer::{Change, FullTracer, Tracer},
};
use anyhow::{Result, anyhow};
Expand All @@ -15,6 +16,23 @@ pub struct HaltState {
pub exit_code: u32,
}

pub enum MultiCycleState {
ShaExtend(ShaExtendState),
}

impl MultiCycleState {
pub fn next_round_effects(&mut self) -> Option<SyscallEffects> {
match self {
MultiCycleState::ShaExtend(state) => state.next_round_effects(),
}
}
pub fn is_done(&self) -> bool {
match self {
MultiCycleState::ShaExtend(state) => state.is_done(),
}
}
}

/// An implementation of the machine state and of the side-effects of operations.
pub const VM_REG_COUNT: usize = 32 + 1;

Expand All @@ -29,6 +47,7 @@ pub struct VMState<T: Tracer = FullTracer> {
// Termination.
halt_state: Option<HaltState>,
tracer: T,
pending_multi_cycle_insn: Option<(Instruction, MultiCycleState)>,
}

impl VMState<FullTracer> {
Expand Down Expand Up @@ -60,6 +79,7 @@ impl<T: Tracer> VMState<T> {
registers: [0; VM_REG_COUNT],
halt_state: None,
tracer: T::new(&platform),
pending_multi_cycle_insn: None,
};

for (&addr, &value) in &program.image {
Expand Down Expand Up @@ -127,9 +147,40 @@ impl<T: Tracer> VMState<T> {
}

fn step(&mut self) -> Result<T::Record> {
if self.pending_multi_cycle_insn.is_some() {
let pc = self.pc;
let insn = self.pending_multi_cycle_insn.as_ref().unwrap().0;
self.tracer.fetch(ByteAddr(pc).waddr(), insn);
let ecall_code = self.peek_register(Platform::reg_ecall());
self.tracer.load_register(Platform::reg_ecall(), ecall_code);

let (_, state) = self.pending_multi_cycle_insn.as_mut().unwrap();
let mut effects = state.next_round_effects().ok_or_else(|| {
anyhow!("pending multi cycle instruction without remaining rounds")
})?;

let is_last_round = state.is_done();
effects.next_pc = if is_last_round {
Some(self.pc + PC_STEP_SIZE as u32)
} else {
Some(self.pc)
};
self.apply_syscall(effects)?;
self.tracer.store_pc(ByteAddr(self.pc));

let step = self.tracer.advance();
if is_last_round {
self.pending_multi_cycle_insn = None;
}
return Ok(step);
}

crate::rv32im::step(self)?;
let step = self.tracer.advance();
if self.tracer.is_busy_loop(&step) && !self.halted() {
if self.tracer.is_busy_loop(&step)
&& self.pending_multi_cycle_insn.is_none()
&& !self.halted()
{
Err(anyhow!("Stuck in loop {}", "{}"))
} else {
Ok(step)
Expand Down Expand Up @@ -173,6 +224,22 @@ impl<T: Tracer> EmuContext for VMState<T> {
tracing::debug!("halt with exit_code={}", exit_code);
self.halt(exit_code);
Ok(true)
} else if function == SHA_EXTEND {
if self.pending_multi_cycle_insn.is_some() {
return Err(anyhow!("nested sha_extend syscall"));
}

let mut state = ShaExtendState::new(self);
let mut effects = state
.next_round_effects()
.ok_or_else(|| anyhow!("sha_extend without rounds"))?;
effects.next_pc = Some(self.pc);
self.pending_multi_cycle_insn = Some((
encode_rv32(InsnKind::ECALL, 0, 0, 0, 0),
MultiCycleState::ShaExtend(state),
));
self.apply_syscall(effects)?;
Ok(true)
} else {
match handle_syscall(self, function) {
Ok(effects) => {
Expand Down
39 changes: 21 additions & 18 deletions ceno_host/tests/test_elf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use std::{collections::BTreeSet, iter::from_fn, sync::Arc};
use anyhow::Result;
use ceno_emul::{
BN254_FP_WORDS, BN254_FP2_WORDS, BN254_POINT_WORDS, CENO_PLATFORM, EmuContext, InsnKind,
Platform, Program, SECP256K1_ARG_WORDS, SECP256K1_COORDINATE_WORDS, SHA_EXTEND_WORDS,
StepRecord, UINT256_WORDS_FIELD_ELEMENT, VMState, WORD_SIZE, Word, WordAddr, WriteOp,
Platform, Program, SECP256K1_ARG_WORDS, SECP256K1_COORDINATE_WORDS, StepRecord,
UINT256_WORDS_FIELD_ELEMENT, VMState, WORD_SIZE, Word, WordAddr, WriteOp,
host_utils::{read_all_messages, read_all_messages_as_words},
};
use ceno_host::CenoStdin;
Expand Down Expand Up @@ -470,15 +470,7 @@ fn test_sha256_extend() -> Result<()> {

let steps = run(&mut state)?;
let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec();
assert_eq!(syscalls.len(), 1);

let witness = syscalls[0];
assert_eq!(witness.reg_ops.len(), 1);
assert_eq!(witness.reg_ops[0].register_index(), Platform::reg_arg0());

let state_ptr = witness.reg_ops[0].value.after;
assert_eq!(state_ptr, witness.reg_ops[0].value.before);
let state_ptr: WordAddr = state_ptr.into();
assert_eq!(syscalls.len(), 48);

let expected = [
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 34013193, 67559435, 1711661200,
Expand All @@ -491,14 +483,25 @@ fn test_sha256_extend() -> Result<()> {
634956631,
];

assert_eq!(witness.mem_ops.len(), SHA_EXTEND_WORDS);
for (round, witness) in syscalls.iter().enumerate() {
assert_eq!(witness.reg_ops.len(), 1);
assert_eq!(witness.reg_ops[0].register_index(), Platform::reg_arg0());

for (i, write_op) in witness.mem_ops.iter().enumerate() {
assert_eq!(write_op.addr, state_ptr + i);
assert_eq!(write_op.value.after, expected[i]);
if i < 16 {
// sanity check: first 16 entries remain unchanged
assert_eq!(write_op.value.before, write_op.value.after);
let state_ptr = witness.reg_ops[0].value.after;
assert_eq!(state_ptr, witness.reg_ops[0].value.before);
let state_ptr: WordAddr = state_ptr.into();

assert_eq!(witness.mem_ops.len(), 5);

let offsets = [-2, -7, -15, -16, 0];
for (i, write_op) in witness.mem_ops.iter().enumerate() {
let mem_round_id = round + (16 + offsets[i]) as usize;
assert_eq!(write_op.addr, state_ptr + mem_round_id as u32);
if i < 4 {
assert_eq!(write_op.value.before, write_op.value.after);
} else {
assert_eq!(write_op.value.after, expected[mem_round_id]);
}
}
}

Expand Down
12 changes: 7 additions & 5 deletions ceno_recursion/src/zkvm_verifier/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1360,12 +1360,14 @@ pub fn evaluate_selector<C: Config>(

(expr, sel)
}
SelectorType::OrderedSparse32 {
SelectorType::OrderedSparse {
num_vars,
indices,
expression,
} => {
let out_point_slice = out_point.slice(builder, 0, 5);
let in_point_slice = in_point.slice(builder, 0, 5);
let num_vars = *num_vars;
let out_point_slice = out_point.slice(builder, 0, num_vars);
let in_point_slice = in_point.slice(builder, 0, num_vars);
let out_subgroup_eq = build_eq_x_r_vec_sequential(builder, &out_point_slice);
let in_subgroup_eq = build_eq_x_r_vec_sequential(builder, &in_point_slice);

Expand All @@ -1376,8 +1378,8 @@ pub fn evaluate_selector<C: Config>(
builder.assign(&eval, eval + out_val * in_val);
}

let out_point_slice = out_point.slice(builder, 5, out_point.len());
let in_point_slice = in_point.slice(builder, 5, in_point.len());
let out_point_slice = out_point.slice(builder, num_vars, out_point.len());
let in_point_slice = in_point.slice(builder, num_vars, in_point.len());
let n_bits = builder.get(&ctx.num_instances_bit_decomps, 0);

let sel =
Expand Down
1 change: 1 addition & 0 deletions ceno_zkvm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ tracing.workspace = true
tracing-forest.workspace = true
tracing-subscriber.workspace = true

arrayref = "0.3.9"
bincode.workspace = true
cfg-if.workspace = true
clap.workspace = true
Expand Down
Loading