Skip to content

Commit

Permalink
Remove unnecessary data transfers
Browse files (browse the repository at this point in the history)
  • Loading branch information
jeremyfelder committed Jan 30, 2025
1 parent a4afa10 commit c92dac7
Showing 2 changed files with 14 additions and 7 deletions.
2 changes: 1 addition & 1 deletion crates/prover/src/core/backend/icicle/circle.rs
Original file line number Diff line number Diff line change
@@ -100,7 +100,7 @@ impl PolyOps for IcicleBackend {
nvtx::range_pop!();

nvtx::range_push!("[ICICLE] fold");
- let folded = crate::core::backend::icicle::utils::fold(&poly.coeffs.to_cpu(), &mappings);
+ let folded = crate::core::backend::icicle::utils::fold::<BaseField, SecureField>(&poly.coeffs, &mappings);
nvtx::range_pop!();
folded
}
19 changes: 13 additions & 6 deletions crates/prover/src/core/backend/icicle/utils.rs
Original file line number Diff line number Diff line change
@@ -2,13 +2,17 @@ use std::mem::transmute;

use icicle_core::ntt::FieldImpl;
use icicle_core::vec_ops::{fold_scalars, VecOps, VecOpsConfig};
- use icicle_cuda_runtime::memory::HostSlice;
+ use icicle_cuda_runtime::memory::{DeviceSlice, HostSlice};
use icicle_m31::field::{ComplexExtensionField, QuarticExtensionField, ScalarField};

- use crate::core::fields::m31::M31;
+ use crate::core::backend::Col;
+ use crate::core::backend::icicle::IcicleBackend;
+ use crate::core::fields::m31::{BaseField, M31};
use crate::core::fields::qm31::QM31;
use crate::core::fields::{ExtensionOf, Field};

+ use super::column::DeviceColumn;

macro_rules! select_result_type {
(1) => {
ScalarField
@@ -44,12 +48,12 @@ macro_rules! select_result_type {
/// factors is provided.
// TODO(Andrew): Can be made to run >10x faster by unrolling lower layers of recursion
pub fn fold<'a, F: Field, E: ExtensionOf<F> + Sized>(
- values: &'a [F],
+ values: &'a DeviceColumn,
folding_factors: &'a [E],
) -> E {
assert!(values.len().is_power_of_two());

- let a = HostSlice::from_slice(unsafe { transmute(values) });
+ let a = unsafe { transmute::<&DeviceSlice<BaseField>, &DeviceSlice<ScalarField>>(&values.data) };
let b = HostSlice::from_slice(unsafe { transmute(folding_factors) });
let mut result = vec![QuarticExtensionField::zero()];
let res = HostSlice::from_mut_slice(&mut result);
@@ -97,7 +101,9 @@ mod tests {
QM31::from_u32_unchecked(3, 0, 0, 0),
QM31::from_u32_unchecked(4, 0, 0, 0),
];
- let result = fold(&values, &folding_factors);

+ let values_device = DeviceColumn::from_cpu(&values);
+ let result = fold::<M31, QM31>(&values_device, &folding_factors);

let expected = QM31::from_u32_unchecked(358, 0, 0, 0);
assert_eq!(result, expected, "Result for simple folding is incorrect");
@@ -125,13 +131,14 @@ mod tests {
.map(|i| M31(i as u32))
.collect();

+ let values_device = DeviceColumn::from_cpu(&values);
// Initialize the `folding_factors` vector
let mut folding_factors = Vec::with_capacity(folding_factors_length);
for i in 2..(2 + folding_factors_length) {
folding_factors.push(QM31::from_u32_unchecked(i as u32, 0, 0, 0));
}
let time = std::time::Instant::now();
- let result = fold(&values, &folding_factors);
+ let result = fold::<M31, QM31>(&values_device, &folding_factors);
let elapsed = time.elapsed();
println!(
"Elapsed time for 2^{}: {:?}",

0 comments on commit c92dac7

Please sign in to comment.