Skip to content

Commit

Permalink
feat: prover memory optimizations (#294)
Browse files Browse the repository at this point in the history
Co-authored-by: John Guibas <jtguibas@Johns-MBP.monkeybrains.net>
Co-authored-by: John Guibas <jtguibas@Johns-MacBook-Pro.local>
  • Loading branch information
3 people authored Feb 23, 2024
1 parent aa71d90 commit ff14294
Show file tree
Hide file tree
Showing 31 changed files with 886 additions and 81 deletions.
44 changes: 44 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ anyhow = "1.0.79"
serial_test = "3.0.0"
petgraph = "0.6.4"
tiny-keccak = { version = "2.0.2", features = ["keccak"] }
hashbrown = "0.14.3"
num_cpus = "1.16.0"

[dev-dependencies]
criterion = "0.5.1"
Expand Down
5 changes: 5 additions & 0 deletions core/src/air/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ pub trait MachineAir<F: Field>: BaseAir<F> {
output: &mut ExecutionRecord,
) -> RowMajorMatrix<F>;

/// Generate the dependencies for a given execution record.
fn generate_dependencies(&self, input: &ExecutionRecord, output: &mut ExecutionRecord) {
self.generate_trace(input, output);
}

/// The number of preprocessed columns in the trace.
fn preprocessed_width(&self) -> usize {
0
Expand Down
36 changes: 27 additions & 9 deletions core/src/alu/add/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use p3_air::{Air, BaseAir};
use p3_field::PrimeField;
use p3_matrix::dense::RowMajorMatrix;
use p3_matrix::MatrixRowSlices;
use p3_maybe_rayon::prelude::ParallelIterator;
use p3_maybe_rayon::prelude::ParallelSlice;
use sp1_derive::AlignedBorrow;
use tracing::instrument;

Expand Down Expand Up @@ -49,16 +51,32 @@ impl<F: PrimeField> MachineAir<F> for AddChip {
output: &mut ExecutionRecord,
) -> RowMajorMatrix<F> {
// Generate the rows for the trace.
let chunk_size = std::cmp::max(input.add_events.len() / num_cpus::get(), 1);
let rows_and_records = input
.add_events
.par_chunks(chunk_size)
.map(|events| {
let mut record = ExecutionRecord::default();
let rows = events
.iter()
.map(|event| {
let mut row = [F::zero(); NUM_ADD_COLS];
let cols: &mut AddCols<F> = row.as_mut_slice().borrow_mut();
cols.add_operation.populate(&mut record, event.b, event.c);
cols.b = Word::from(event.b);
cols.c = Word::from(event.c);
cols.is_real = F::one();
row
})
.collect::<Vec<_>>();
(rows, record)
})
.collect::<Vec<_>>();

let mut rows: Vec<[F; NUM_ADD_COLS]> = vec![];
for i in 0..input.add_events.len() {
let mut row = [F::zero(); NUM_ADD_COLS];
let cols: &mut AddCols<F> = row.as_mut_slice().borrow_mut();
let event = input.add_events[i];
cols.add_operation.populate(output, event.b, event.c);
cols.b = Word::from(event.b);
cols.c = Word::from(event.c);
cols.is_real = F::one();
rows.push(row);
for mut row_and_record in rows_and_records {
rows.extend(row_and_record.0);
output.append(&mut row_and_record.1);
}

// Convert the trace to a row major matrix.
Expand Down
45 changes: 44 additions & 1 deletion core/src/cpu/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@ use crate::disassembler::WORD_SIZE;
use crate::field::event::FieldEvent;
use crate::memory::MemoryCols;
use crate::runtime::{ExecutionRecord, Opcode};
use hashbrown::HashMap;
use p3_field::PrimeField;
use p3_matrix::dense::RowMajorMatrix;
use p3_maybe_rayon::prelude::IntoParallelRefIterator;
use p3_maybe_rayon::prelude::ParallelIterator;
use p3_maybe_rayon::prelude::ParallelSlice;
use std::borrow::BorrowMut;
use std::collections::HashMap;
use tracing::instrument;

impl<F: PrimeField> MachineAir<F> for CpuChip {
Expand Down Expand Up @@ -71,6 +72,48 @@ impl<F: PrimeField> MachineAir<F> for CpuChip {

trace
}

#[instrument(name = "generate CPU dependencies", skip_all)]
fn generate_dependencies(&self, input: &ExecutionRecord, output: &mut ExecutionRecord) {
let mut new_alu_events = HashMap::with_capacity(input.cpu_events.len());
let mut new_blu_events = Vec::with_capacity(input.cpu_events.len());
let mut new_field_events: Vec<FieldEvent> = Vec::with_capacity(input.cpu_events.len());

// Generate the trace rows for each event.
let chunk_size = std::cmp::max(input.cpu_events.len() / num_cpus::get(), 1);
let events = input
.cpu_events
.par_chunks(chunk_size)
.map(|ops: &[CpuEvent]| {
ops.iter()
.map(|op| {
let (_, alu_events, blu_events, field_events) = self.event_to_row::<F>(*op);
(alu_events, blu_events, field_events)
})
.collect::<Vec<_>>()
})
.flatten()
.collect::<Vec<_>>();

events.into_iter().for_each(|e| {
let (alu_events, blu_events, field_events) = e;
for (key, value) in alu_events {
new_alu_events
.entry(key)
.and_modify(|op_new_events: &mut Vec<AluEvent>| {
op_new_events.extend(value.clone())
})
.or_insert(value);
}
new_blu_events.extend(blu_events);
new_field_events.extend(field_events);
});

// Add the dependency events to the shard.
output.add_alu_events(new_alu_events);
output.add_byte_lookup_events(new_blu_events);
output.add_field_events(&new_field_events);
}
}

impl CpuChip {
Expand Down
8 changes: 8 additions & 0 deletions core/src/disassembler/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,19 @@ impl Instruction {
}

/// Returns if the instruction is an R-type instruction.
#[inline(always)]
pub fn is_r_type(&self) -> bool {
!self.imm_c
}

/// Returns whether the instruction is an I-type instruction.
#[inline(always)]
pub fn is_i_type(&self) -> bool {
self.imm_c
}

/// Decode the instruction in the R-type format.
#[inline(always)]
pub fn r_type(&self) -> (Register, Register, Register) {
(
Register::from_u32(self.op_a),
Expand All @@ -91,6 +94,7 @@ impl Instruction {
}

/// Decode the instruction in the I-type format.
#[inline(always)]
pub fn i_type(&self) -> (Register, Register, u32) {
(
Register::from_u32(self.op_a),
Expand All @@ -100,6 +104,7 @@ impl Instruction {
}

/// Decode the instruction in the S-type format.
#[inline(always)]
pub fn s_type(&self) -> (Register, Register, u32) {
(
Register::from_u32(self.op_a),
Expand All @@ -109,6 +114,7 @@ impl Instruction {
}

/// Decode the instruction in the B-type format.
#[inline(always)]
pub fn b_type(&self) -> (Register, Register, u32) {
(
Register::from_u32(self.op_a),
Expand All @@ -118,11 +124,13 @@ impl Instruction {
}

/// Decode the instruction in the J-type format.
#[inline(always)]
pub fn j_type(&self) -> (Register, u32) {
(Register::from_u32(self.op_a), self.op_b)
}

/// Decode the instruction in the U-type format.
#[inline(always)]
pub fn u_type(&self) -> (Register, u32) {
(Register::from_u32(self.op_a), self.op_b)
}
Expand Down
4 changes: 3 additions & 1 deletion core/src/field/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ use p3_air::{Air, AirBuilder, BaseAir};
use p3_field::{AbstractField, Field, PrimeField};
use p3_matrix::dense::RowMajorMatrix;
use p3_matrix::MatrixRowSlices;
use p3_maybe_rayon::prelude::IntoParallelRefIterator;
use p3_maybe_rayon::prelude::ParallelIterator;
use sp1_derive::AlignedBorrow;

use crate::air::FieldAirBuilder;
Expand Down Expand Up @@ -59,7 +61,7 @@ impl<F: PrimeField> MachineAir<F> for FieldLTUChip {
// Generate the trace rows for each event.
let rows = input
.field_events
.iter()
.par_iter()
.map(|event| {
let mut row = [F::zero(); NUM_FIELD_COLS];
let cols: &mut FieldLTUCols<F> = row.as_mut_slice().borrow_mut();
Expand Down
10 changes: 6 additions & 4 deletions core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,12 @@ impl SP1Prover {
runtime.run();
});
let config = BabyBearBlake3::new();
let proof = prove_core(config, &mut runtime);
let stdout = SP1Stdout::from(&runtime.state.output_stream);
let proof = prove_core(config, runtime);
Ok(SP1ProofWithIO {
proof,
stdin,
stdout: SP1Stdout::from(&runtime.state.output_stream),
stdout,
})
}

Expand All @@ -103,11 +104,12 @@ impl SP1Prover {
let mut runtime = Runtime::new(program);
runtime.write_stdin_slice(&stdin.buffer.data);
runtime.run();
let proof = prove_core(config, &mut runtime);
let stdout = SP1Stdout::from(&runtime.state.output_stream);
let proof = prove_core(config, runtime);
Ok(SP1ProofWithIO {
proof,
stdin,
stdout: SP1Stdout::from(&runtime.state.output_stream),
stdout,
})
}
}
Expand Down
2 changes: 1 addition & 1 deletion core/src/runtime/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,6 @@ pub mod tests {
runtime.write_stdin(&points.1);
runtime.run();
let config = BabyBearBlake3::new();
prove_core(config, &mut runtime);
prove_core(config, runtime);
}
}
Loading

0 comments on commit ff14294

Please sign in to comment.