Skip to content

Commit

Permalink
RadixN: Like Radix4, but supports size 2,3,4,5,6, and 7 cross-FFTs al…
Browse files Browse the repository at this point in the history
…l in the same instance (#132)

* Implemented RadixN: Like radix4 but for mixed cross-FFTs

* Remove stale code

* remove explicit enum discriminants to unblock the build

* Re-added function that appeared unused, and re-added default features

* Removed another non-1.61 feature

* RadixN Cleanup
  • Loading branch information
ejmahler authored Feb 7, 2024
1 parent 4c1dda2 commit 01fa5c8
Show file tree
Hide file tree
Showing 12 changed files with 921 additions and 219 deletions.
7 changes: 7 additions & 0 deletions src/algorithm/dft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ impl<T: FftNum> Dft<T> {
}
}

/// Scratch length reported for in-place processing: the value of `self.len()`.
fn inplace_scratch_len(&self) -> usize {
self.len()
}
/// Scratch length reported for out-of-place processing: always zero.
fn outofplace_scratch_len(&self) -> usize {
0
}

fn perform_fft_out_of_place(
&self,
signal: &[Complex<T>],
Expand Down
2 changes: 2 additions & 0 deletions src/algorithm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ mod mixed_radix;
mod raders_algorithm;
mod radix3;
mod radix4;
mod radixn;

/// Hardcoded size-specific FFT algorithms
pub mod butterflies;
Expand All @@ -16,3 +17,4 @@ pub use self::mixed_radix::{MixedRadix, MixedRadixSmall};
pub use self::raders_algorithm::RadersAlgorithm;
pub use self::radix3::Radix3;
pub use self::radix4::Radix4;
pub use self::radixn::RadixN;
89 changes: 40 additions & 49 deletions src/algorithm/radix3.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std::sync::Arc;

use num_complex::Complex;
use num_traits::Zero;

use crate::algorithm::butterflies::{Butterfly1, Butterfly27, Butterfly3, Butterfly9};
use crate::algorithm::radixn::butterfly_3;
use crate::array_utils::{self, bitreversed_transpose, compute_logarithm};
use crate::common::{fft_error_inplace, fft_error_outofplace};
use crate::{common::FftNum, twiddles, FftDirection};
Expand Down Expand Up @@ -32,6 +32,8 @@ pub struct Radix3<T> {

len: usize,
direction: FftDirection,
inplace_scratch_len: usize,
outofplace_scratch_len: usize,
}

impl<T: FftNum> Radix3<T> {
Expand Down Expand Up @@ -68,20 +70,32 @@ impl<T: FftNum> Radix3<T> {
// but mixed radix only does one step and then calls itself recursively, and this algorithm does every layer all the way down
// so we're going to pack all the "layers" of twiddle factors into a single array, starting with the bottom layer and going up
const ROW_COUNT: usize = 3;
let mut cross_fft_len = base_len * ROW_COUNT;
let mut cross_fft_len = base_len;
let mut twiddle_factors = Vec::with_capacity(len * 2);
while cross_fft_len <= len {
let num_columns = cross_fft_len / ROW_COUNT;
while cross_fft_len < len {
let num_columns = cross_fft_len;
cross_fft_len *= ROW_COUNT;

for i in 0..num_columns {
for k in 1..ROW_COUNT {
let twiddle = twiddles::compute_twiddle(i * k, cross_fft_len, direction);
twiddle_factors.push(twiddle);
}
}
cross_fft_len *= ROW_COUNT;
}

let base_inplace_scratch = base_fft.get_inplace_scratch_len();
let inplace_scratch_len = if base_inplace_scratch > cross_fft_len {
cross_fft_len + base_inplace_scratch
} else {
cross_fft_len
};
let outofplace_scratch_len = if base_inplace_scratch > len {
base_inplace_scratch
} else {
0
};

Self {
twiddles: twiddle_factors.into_boxed_slice(),
butterfly3: Butterfly3::new(direction),
Expand All @@ -91,14 +105,24 @@ impl<T: FftNum> Radix3<T> {

len,
direction,

inplace_scratch_len,
outofplace_scratch_len,
}
}

/// Returns the in-place scratch length stored in the `inplace_scratch_len` field.
fn inplace_scratch_len(&self) -> usize {
self.inplace_scratch_len
}
/// Returns the out-of-place scratch length stored in the `outofplace_scratch_len` field.
fn outofplace_scratch_len(&self) -> usize {
self.outofplace_scratch_len
}

fn perform_fft_out_of_place(
&self,
input: &[Complex<T>],
input: &mut [Complex<T>],
output: &mut [Complex<T>],
_scratch: &mut [Complex<T>],
scratch: &mut [Complex<T>],
) {
// copy the data into the output vector
if self.len() == self.base_len {
Expand All @@ -108,63 +132,30 @@ impl<T: FftNum> Radix3<T> {
}

// Base-level FFTs
self.base_fft.process_with_scratch(output, &mut []);
let base_scratch = if scratch.len() > 0 { scratch } else { input };
self.base_fft.process_with_scratch(output, base_scratch);

// cross-FFTs
const ROW_COUNT: usize = 3;
let mut cross_fft_len = self.base_len * ROW_COUNT;
let mut cross_fft_len = self.base_len;
let mut layer_twiddles: &[Complex<T>] = &self.twiddles;

while cross_fft_len <= input.len() {
let num_rows = input.len() / cross_fft_len;
let num_columns = cross_fft_len / ROW_COUNT;

for i in 0..num_rows {
unsafe {
butterfly_3(
&mut output[i * cross_fft_len..],
layer_twiddles,
num_columns,
&self.butterfly3,
)
}
while cross_fft_len < output.len() {
let num_columns = cross_fft_len;
cross_fft_len *= ROW_COUNT;

for data in output.chunks_exact_mut(cross_fft_len) {
unsafe { butterfly_3(data, layer_twiddles, num_columns, &self.butterfly3) }
}

// skip past all the twiddle factors used in this layer
let twiddle_offset = num_columns * (ROW_COUNT - 1);
layer_twiddles = &layer_twiddles[twiddle_offset..];

cross_fft_len *= ROW_COUNT;
}
}
}
boilerplate_fft_oop!(Radix3, |this: &Radix3<_>| this.len);

/// Runs one layer of size-3 cross-FFTs over `data`, one per column.
///
/// For every column `c` in `0..num_ffts`, the elements at `c`, `c + num_ffts` and
/// `c + 2 * num_ffts` are gathered; the second and third are multiplied by the two
/// twiddle factors belonging to that column (`twiddles[2 * c]` and `twiddles[2 * c + 1]`),
/// `butterfly3` is applied to the three values, and the results are written back to
/// the positions they were read from.
///
/// # Safety
///
/// `data` is accessed without bounds checks, so it must hold at least `3 * num_ffts`
/// elements. `twiddles` is indexed with bounds checks and will panic if it holds
/// fewer than `2 * num_ffts` elements.
unsafe fn butterfly_3<T: FftNum>(
    data: &mut [Complex<T>],
    twiddles: &[Complex<T>],
    num_ffts: usize,
    butterfly3: &Butterfly3<T>,
) {
    let mut column = [Zero::zero(); 3];
    for col in 0..num_ffts {
        // Two twiddle factors are consumed per column (row 0 is never multiplied).
        let tw_base = col * 2;
        let row1 = col + num_ffts;
        let row2 = col + 2 * num_ffts;

        column[0] = *data.get_unchecked(col);
        column[1] = *data.get_unchecked(row1) * twiddles[tw_base];
        column[2] = *data.get_unchecked(row2) * twiddles[tw_base + 1];

        butterfly3.perform_fft_butterfly(&mut column);

        *data.get_unchecked_mut(col) = column[0];
        *data.get_unchecked_mut(row1) = column[1];
        *data.get_unchecked_mut(row2) = column[2];
    }
}

#[cfg(test)]
mod unit_tests {
use super::*;
Expand Down
95 changes: 42 additions & 53 deletions src/algorithm/radix4.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use std::sync::Arc;

use num_complex::Complex;
use num_traits::Zero;

use crate::algorithm::butterflies::{
Butterfly1, Butterfly16, Butterfly2, Butterfly32, Butterfly4, Butterfly8,
};
use crate::algorithm::radixn::butterfly_4;
use crate::array_utils::{self, bitreversed_transpose};
use crate::common::{fft_error_inplace, fft_error_outofplace};
use crate::{common::FftNum, twiddles, FftDirection};
Expand Down Expand Up @@ -33,6 +33,8 @@ pub struct Radix4<T> {

len: usize,
direction: FftDirection,
inplace_scratch_len: usize,
outofplace_scratch_len: usize,
}

impl<T: FftNum> Radix4<T> {
Expand Down Expand Up @@ -75,20 +77,32 @@ impl<T: FftNum> Radix4<T> {
// but mixed radix only does one step and then calls itself recursively, and this algorithm does every layer all the way down
// so we're going to pack all the "layers" of twiddle factors into a single array, starting with the bottom layer and going up
const ROW_COUNT: usize = 4;
let mut cross_fft_len = base_len * ROW_COUNT;
let mut cross_fft_len = base_len;
let mut twiddle_factors = Vec::with_capacity(len * 2);
while cross_fft_len <= len {
let num_columns = cross_fft_len / ROW_COUNT;
while cross_fft_len < len {
let num_columns = cross_fft_len;
cross_fft_len *= ROW_COUNT;

for i in 0..num_columns {
for k in 1..ROW_COUNT {
let twiddle = twiddles::compute_twiddle(i * k, cross_fft_len, direction);
twiddle_factors.push(twiddle);
}
}
cross_fft_len *= ROW_COUNT;
}

let base_inplace_scratch = base_fft.get_inplace_scratch_len();
let inplace_scratch_len = if base_inplace_scratch > cross_fft_len {
cross_fft_len + base_inplace_scratch
} else {
cross_fft_len
};
let outofplace_scratch_len = if base_inplace_scratch > len {
base_inplace_scratch
} else {
0
};

Self {
twiddles: twiddle_factors.into_boxed_slice(),

Expand All @@ -97,14 +111,24 @@ impl<T: FftNum> Radix4<T> {

len,
direction,

inplace_scratch_len,
outofplace_scratch_len,
}
}

/// Returns the in-place scratch length stored in the `inplace_scratch_len` field.
fn inplace_scratch_len(&self) -> usize {
self.inplace_scratch_len
}
/// Returns the out-of-place scratch length stored in the `outofplace_scratch_len` field.
fn outofplace_scratch_len(&self) -> usize {
self.outofplace_scratch_len
}

fn perform_fft_out_of_place(
&self,
input: &[Complex<T>],
input: &mut [Complex<T>],
output: &mut [Complex<T>],
_scratch: &mut [Complex<T>],
scratch: &mut [Complex<T>],
) {
// copy the data into the output vector
if self.len() == self.base_len {
Expand All @@ -114,67 +138,32 @@ impl<T: FftNum> Radix4<T> {
}

// Base-level FFTs
self.base_fft.process_with_scratch(output, &mut []);
let base_scratch = if scratch.len() > 0 { scratch } else { input };
self.base_fft.process_with_scratch(output, base_scratch);

// cross-FFTs
const ROW_COUNT: usize = 4;
let mut cross_fft_len = self.base_len * ROW_COUNT;
let mut cross_fft_len = self.base_len;
let mut layer_twiddles: &[Complex<T>] = &self.twiddles;

while cross_fft_len <= input.len() {
let num_rows = input.len() / cross_fft_len;
let num_columns = cross_fft_len / ROW_COUNT;

for i in 0..num_rows {
unsafe {
butterfly_4(
&mut output[i * cross_fft_len..],
layer_twiddles,
num_columns,
self.direction,
)
}
let butterfly4 = Butterfly4::new(self.direction);

while cross_fft_len < output.len() {
let num_columns = cross_fft_len;
cross_fft_len *= ROW_COUNT;

for data in output.chunks_exact_mut(cross_fft_len) {
unsafe { butterfly_4(data, layer_twiddles, num_columns, &butterfly4) }
}

// skip past all the twiddle factors used in this layer
let twiddle_offset = num_columns * (ROW_COUNT - 1);
layer_twiddles = &layer_twiddles[twiddle_offset..];

cross_fft_len *= ROW_COUNT;
}
}
}
boilerplate_fft_oop!(Radix4, |this: &Radix4<_>| this.len);

/// Runs one layer of size-4 cross-FFTs over `data`, one per column.
///
/// A `Butterfly4` is built for `direction` on every call. For every column `c` in
/// `0..num_ffts`, the elements at `c`, `c + num_ffts`, `c + 2 * num_ffts` and
/// `c + 3 * num_ffts` are gathered; the last three are multiplied by the three twiddle
/// factors belonging to that column (`twiddles[3 * c ..= 3 * c + 2]`), the butterfly is
/// applied, and the results are written back to the positions they were read from.
///
/// # Safety
///
/// `data` is accessed without bounds checks, so it must hold at least `4 * num_ffts`
/// elements. `twiddles` is indexed with bounds checks and will panic if it holds
/// fewer than `3 * num_ffts` elements.
unsafe fn butterfly_4<T: FftNum>(
    data: &mut [Complex<T>],
    twiddles: &[Complex<T>],
    num_ffts: usize,
    direction: FftDirection,
) {
    let butterfly4 = Butterfly4::new(direction);

    let mut column = [Zero::zero(); 4];
    for col in 0..num_ffts {
        // Three twiddle factors are consumed per column (row 0 is never multiplied).
        let tw_base = col * 3;
        let row1 = col + num_ffts;
        let row2 = col + 2 * num_ffts;
        let row3 = col + 3 * num_ffts;

        column[0] = *data.get_unchecked(col);
        column[1] = *data.get_unchecked(row1) * twiddles[tw_base];
        column[2] = *data.get_unchecked(row2) * twiddles[tw_base + 1];
        column[3] = *data.get_unchecked(row3) * twiddles[tw_base + 2];

        butterfly4.perform_fft_butterfly(&mut column);

        *data.get_unchecked_mut(col) = column[0];
        *data.get_unchecked_mut(row1) = column[1];
        *data.get_unchecked_mut(row2) = column[2];
        *data.get_unchecked_mut(row3) = column[3];
    }
}

#[cfg(test)]
mod unit_tests {
use super::*;
Expand Down
Loading

0 comments on commit 01fa5c8

Please sign in to comment.