Skip to content

Commit

Permalink
PolyOps::extend
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremyfelder committed Jan 30, 2025
1 parent 36f051a commit ab836c1
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 88 deletions.
1 change: 1 addition & 0 deletions crates/prover/src/core/backend/icicle/accumulation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ mod tests {
use crate::core::backend::{Column, CpuBackend};
use crate::core::{air::accumulation::AccumulationOps, backend::icicle::IcicleBackend, fields::{qm31::SecureField, secure_column::SecureColumnByCoords}};
use num_traits::{Zero, One};
use crate::core::fields::m31::BaseField;

#[cfg(feature = "icicle")]
#[test]
Expand Down
209 changes: 121 additions & 88 deletions crates/prover/src/core/backend/icicle/circle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,114 +40,118 @@ impl PolyOps for IcicleBackend {
eval: CircleEvaluation<Self, BaseField, BitReversedOrder>,
itwiddles: &TwiddleTree<Self>,
) -> CirclePoly<Self> {
// todo!()
if eval.domain.log_size() <= 3 || eval.domain.log_size() == 7 {
// TODO: as property .is_dcct_available etc...
return unsafe {
transmute(CpuBackend::interpolate(
transmute(eval),
transmute(itwiddles),
))
};
}
todo!()
// if eval.domain.log_size() <= 3 || eval.domain.log_size() == 7 {
// // TODO: as property .is_dcct_available etc...
// return unsafe {
// transmute(CpuBackend::interpolate(
// transmute(eval),
// transmute(itwiddles),
// ))
// };
// }

let values = eval.values;
nvtx::range_push!("[ICICLE] get_dcct_root_of_unity");
let rou = get_dcct_root_of_unity(eval.domain.size() as _);
nvtx::range_pop!();
// let values = eval.values;
// nvtx::range_push!("[ICICLE] get_dcct_root_of_unity");
// let rou = get_dcct_root_of_unity(eval.domain.size() as _);
// nvtx::range_pop!();

nvtx::range_push!("[ICICLE] initialize_dcct_domain");
initialize_dcct_domain(eval.domain.log_size(), rou, &DeviceContext::default()).unwrap();
nvtx::range_pop!();
// nvtx::range_push!("[ICICLE] initialize_dcct_domain");
// initialize_dcct_domain(eval.domain.log_size(), rou, &DeviceContext::default()).unwrap();
// nvtx::range_pop!();

let mut evaluations = vec![ScalarField::zero(); values.len()];
let values: Vec<ScalarField> = unsafe { transmute(values) };
let mut cfg = NTTConfig::default();
cfg.ordering = Ordering::kMN;
nvtx::range_push!("[ICICLE] interpolate");
dcct::interpolate(
HostSlice::from_slice(&values),
&cfg,
HostSlice::from_mut_slice(&mut evaluations),
)
.unwrap();
nvtx::range_pop!();
// let mut evaluations = vec![ScalarField::zero(); values.len()];
// let values: Vec<ScalarField> = unsafe { transmute(values) };
// let mut cfg = NTTConfig::default();
// cfg.ordering = Ordering::kMN;
// nvtx::range_push!("[ICICLE] interpolate");
// dcct::interpolate(
// HostSlice::from_slice(&values),
// &cfg,
// HostSlice::from_mut_slice(&mut evaluations),
// )
// .unwrap();
// nvtx::range_pop!();

let values: Vec<BaseField> = unsafe { transmute(evaluations) };
// let values: Vec<BaseField> = unsafe { transmute(evaluations) };

CirclePoly::new(DeviceColumn::from_cpu(&values))
// CirclePoly::new(DeviceColumn::from_cpu(&values))
}

fn eval_at_point(poly: &CirclePoly<Self>, point: CirclePoint<SecureField>) -> SecureField {
// todo!()
todo!()
// unsafe { CpuBackend::eval_at_point(transmute(poly), point) }
if poly.log_size() == 0 {
return poly.coeffs.to_cpu()[0].into();
}
// TODO: to gpu after correctness fix
nvtx::range_push!("[ICICLE] create mappings");
let mut mappings = vec![point.y];
let mut x = point.x;
for _ in 1..poly.log_size() {
mappings.push(x);
x = CirclePoint::double_x(x);
}
mappings.reverse();
nvtx::range_pop!();
// if poly.log_size() == 0 {
// return poly.coeffs.to_cpu()[0].into();
// }
// // TODO: to gpu after correctness fix
// nvtx::range_push!("[ICICLE] create mappings");
// let mut mappings = vec![point.y];
// let mut x = point.x;
// for _ in 1..poly.log_size() {
// mappings.push(x);
// x = CirclePoint::double_x(x);
// }
// mappings.reverse();
// nvtx::range_pop!();

nvtx::range_push!("[ICICLE] fold");
let folded = crate::core::backend::icicle::utils::fold(&poly.coeffs.to_cpu(), &mappings);
nvtx::range_pop!();
folded
// nvtx::range_push!("[ICICLE] fold");
// let folded = crate::core::backend::icicle::utils::fold(&poly.coeffs.to_cpu(), &mappings);
// nvtx::range_pop!();
// folded
}

fn extend(poly: &CirclePoly<Self>, log_size: u32) -> CirclePoly<Self> {
// todo!()
unsafe { transmute(CpuBackend::extend(transmute(poly), log_size)) }
assert!(log_size >= poly.log_size());

let mut device_column = DeviceColumn::zeros(1 << log_size);
poly.coeffs.data.copy_to_device(&mut device_column.data).unwrap();

CirclePoly::new(device_column)
}

fn evaluate(
poly: &CirclePoly<Self>,
domain: CircleDomain,
twiddles: &TwiddleTree<Self>,
) -> CircleEvaluation<Self, BaseField, BitReversedOrder> {
// todo!()
if domain.log_size() <= 3 || domain.log_size() == 7 {
return unsafe {
transmute(CpuBackend::evaluate(
transmute(poly),
domain,
transmute(twiddles),
))
};
}
todo!()
// if domain.log_size() <= 3 || domain.log_size() == 7 {
// return unsafe {
// transmute(CpuBackend::evaluate(
// transmute(poly),
// domain,
// transmute(twiddles),
// ))
// };
// }

let values = poly.extend(domain.log_size()).coeffs;
nvtx::range_push!("[ICICLE] get_dcct_root_of_unity");
let rou = get_dcct_root_of_unity(domain.size() as _);
nvtx::range_pop!();
nvtx::range_push!("[ICICLE] initialize_dcct_domain");
initialize_dcct_domain(domain.log_size(), rou, &DeviceContext::default()).unwrap();
nvtx::range_pop!();
// let values = poly.extend(domain.log_size()).coeffs;
// nvtx::range_push!("[ICICLE] get_dcct_root_of_unity");
// let rou = get_dcct_root_of_unity(domain.size() as _);
// nvtx::range_pop!();
// nvtx::range_push!("[ICICLE] initialize_dcct_domain");
// initialize_dcct_domain(domain.log_size(), rou, &DeviceContext::default()).unwrap();
// nvtx::range_pop!();

let mut evaluations = vec![ScalarField::zero(); values.len()];
let values: Vec<ScalarField> = unsafe { transmute(values) };
let mut cfg = NTTConfig::default();
cfg.ordering = Ordering::kNM;
nvtx::range_push!("[ICICLE] evaluate");
dcct::evaluate(
HostSlice::from_slice(&values),
&cfg,
HostSlice::from_mut_slice(&mut evaluations),
)
.unwrap();
nvtx::range_pop!();
unsafe {
transmute(IcicleCircleEvaluation::<BaseField, BitReversedOrder>::new(
domain,
transmute(evaluations),
))
}
// let mut evaluations = vec![ScalarField::zero(); values.len()];
// let values: Vec<ScalarField> = unsafe { transmute(values) };
// let mut cfg = NTTConfig::default();
// cfg.ordering = Ordering::kNM;
// nvtx::range_push!("[ICICLE] evaluate");
// dcct::evaluate(
// HostSlice::from_slice(&values),
// &cfg,
// HostSlice::from_mut_slice(&mut evaluations),
// )
// .unwrap();
// nvtx::range_pop!();
// unsafe {
// transmute(IcicleCircleEvaluation::<BaseField, BitReversedOrder>::new(
// domain,
// transmute(evaluations),
// ))
// }
}

fn interpolate_columns(
Expand Down Expand Up @@ -262,7 +266,36 @@ impl PolyOps for IcicleBackend {
}

fn precompute_twiddles(coset: Coset) -> TwiddleTree<Self> {
// todo!()
unsafe { transmute(CpuBackend::precompute_twiddles(coset)) }
todo!()
// unsafe { transmute(CpuBackend::precompute_twiddles(coset)) }
}
}


#[cfg(test)]
mod tests {
use crate::core::backend::CpuBackend;
use crate::core::fields::m31::BaseField;
use crate::core::poly::circle::{CanonicCoset, PolyOps};
use super::IcicleCirclePoly;

// #[cfg(feature = "icicle")]
#[test]
fn test_extend() {
use num_traits::Zero;

use crate::core::backend::{icicle::column::DeviceColumn, Column};

let cpu_col = (1..=8).map(BaseField::from).collect::<Vec<BaseField>>();
let device_col = DeviceColumn::from_cpu(&cpu_col);
let poly = IcicleCirclePoly::new(device_col);
let vals = poly.extend(4);

let vals_on_cpu = vals.coeffs.to_cpu().to_vec();
let (first_8_vals, rest) = vals_on_cpu.split_at(8);
let poly_on_cpu = poly.coeffs.to_cpu().to_vec();

assert_eq!(poly_on_cpu, first_8_vals);
assert!(rest.iter().all(|&item| item == BaseField::zero()));
}
}

0 comments on commit ab836c1

Please sign in to comment.